Return session name in runner response (#6112)

Makes it easier to then run evals w/o thinking about specifying a
session
This commit is contained in:
Zander Chase 2023-06-13 16:59:43 -07:00 committed by GitHub
parent e74733ab9e
commit b3b155d488
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 25 additions and 31 deletions

View File

@ -422,14 +422,14 @@ async def arun_on_dataset(
client will be created using the credentials in the environment. client will be created using the credentials in the environment.
Returns: Returns:
A dictionary mapping example ids to the model outputs. A dictionary containing the run's session name and the resulting model outputs.
""" """
client_ = client or LangChainPlusClient() client_ = client or LangChainPlusClient()
session_name = _get_session_name(session_name, llm_or_chain_factory, dataset_name) session_name = _get_session_name(session_name, llm_or_chain_factory, dataset_name)
dataset = client_.read_dataset(dataset_name=dataset_name) dataset = client_.read_dataset(dataset_name=dataset_name)
examples = client_.list_examples(dataset_id=str(dataset.id)) examples = client_.list_examples(dataset_id=str(dataset.id))
return await arun_on_examples( results = await arun_on_examples(
examples, examples,
llm_or_chain_factory, llm_or_chain_factory,
concurrency_level=concurrency_level, concurrency_level=concurrency_level,
@ -437,6 +437,10 @@ async def arun_on_dataset(
session_name=session_name, session_name=session_name,
verbose=verbose, verbose=verbose,
) )
return {
"session_name": session_name,
"results": results,
}
def run_on_dataset( def run_on_dataset(
@ -466,16 +470,20 @@ def run_on_dataset(
will be created using the credentials in the environment. will be created using the credentials in the environment.
Returns: Returns:
A dictionary mapping example ids to the model outputs. A dictionary containing the run's session name and the resulting model outputs.
""" """
client_ = client or LangChainPlusClient() client_ = client or LangChainPlusClient()
session_name = _get_session_name(session_name, llm_or_chain_factory, dataset_name) session_name = _get_session_name(session_name, llm_or_chain_factory, dataset_name)
dataset = client_.read_dataset(dataset_name=dataset_name) dataset = client_.read_dataset(dataset_name=dataset_name)
examples = client_.list_examples(dataset_id=str(dataset.id)) examples = client_.list_examples(dataset_id=str(dataset.id))
return run_on_examples( results = run_on_examples(
examples, examples,
llm_or_chain_factory, llm_or_chain_factory,
num_repetitions=num_repetitions, num_repetitions=num_repetitions,
session_name=session_name, session_name=session_name,
verbose=verbose, verbose=verbose,
) )
return {
"session_name": session_name,
"results": results,
}

View File

@ -212,6 +212,8 @@
" error=False, # Only runs that succeed\n", " error=False, # Only runs that succeed\n",
")\n", ")\n",
"for run in runs:\n", "for run in runs:\n",
" if run.outputs is None:\n",
" continue\n",
" try:\n", " try:\n",
" client.create_example(\n", " client.create_example(\n",
" inputs=run.inputs, outputs=run.outputs, dataset_id=dataset.id\n", " inputs=run.inputs, outputs=run.outputs, dataset_id=dataset.id\n",
@ -388,7 +390,7 @@
" client will be created using the credentials in the environment.\n", " client will be created using the credentials in the environment.\n",
"\n", "\n",
"Returns:\n", "Returns:\n",
" A dictionary mapping example ids to the model outputs.\n", " A dictionary containing the run's session name and the resulting model outputs.\n",
"\u001b[0;31mFile:\u001b[0m ~/code/lc/lckg/langchain/client/runner_utils.py\n", "\u001b[0;31mFile:\u001b[0m ~/code/lc/lckg/langchain/client/runner_utils.py\n",
"\u001b[0;31mType:\u001b[0m function" "\u001b[0;31mType:\u001b[0m function"
] ]
@ -438,16 +440,14 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Processed examples: 3\r" "Processed examples: 4\r"
] ]
}, },
{ {
"name": "stderr", "name": "stderr",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Chain failed for example 59fb1b4d-d935-4e43-b2a7-bc33fde841bb. Error: LLMMathChain._evaluate(\"\n", "Chain failed for example c855f923-4165-4fe0-a909-360749f3f764. Error: Could not parse LLM output: `The final answer is that there were no more points scored in the 2023 Super Bowl than in the 2022 Super Bowl.`\n"
"round(0.2791714614499425, 2)\n",
"\") raised error: 'VariableNode' object is not callable. Please try again with a valid numerical expression\n"
] ]
}, },
{ {
@ -459,13 +459,11 @@
} }
], ],
"source": [ "source": [
"evaluation_session_name = \"Search + Calculator Agent Evaluation\"\n",
"chain_results = await arun_on_dataset(\n", "chain_results = await arun_on_dataset(\n",
" dataset_name=dataset_name,\n", " dataset_name=dataset_name,\n",
" llm_or_chain_factory=chain_factory,\n", " llm_or_chain_factory=chain_factory,\n",
" concurrency_level=5, # Optional, sets the number of examples to run at a time\n", " concurrency_level=5, # Optional, sets the number of examples to run at a time\n",
" verbose=True,\n", " verbose=True,\n",
" session_name=evaluation_session_name, # Optional, a unique session name will be generated if not provided\n",
" client=client,\n", " client=client,\n",
")\n", ")\n",
"\n", "\n",
@ -558,7 +556,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 27, "execution_count": 16,
"id": "4c94a738-dcd3-442e-b8e7-dd36459f56e3", "id": "4c94a738-dcd3-442e-b8e7-dd36459f56e3",
"metadata": { "metadata": {
"tags": [] "tags": []
@ -567,7 +565,7 @@
{ {
"data": { "data": {
"application/vnd.jupyter.widget-view+json": { "application/vnd.jupyter.widget-view+json": {
"model_id": "a185493c1af74cbaa0f9b10f32cf81c6", "model_id": "9989f6507cd04ea7a09ea3c5723dc984",
"version_major": 2, "version_major": 2,
"version_minor": 0 "version_minor": 0
}, },
@ -582,8 +580,10 @@
"source": [ "source": [
"from tqdm.notebook import tqdm\n", "from tqdm.notebook import tqdm\n",
"feedbacks = []\n", "feedbacks = []\n",
"runs = client.list_runs(session_name=evaluation_session_name, execution_order=1, error=False)\n", "runs = client.list_runs(session_name=chain_results[\"session_name\"], execution_order=1, error=False)\n",
"for run in tqdm(runs):\n", "for run in tqdm(runs):\n",
" if run.outputs is None:\n",
" continue\n",
" eval_feedback = []\n", " eval_feedback = []\n",
" for evaluator in evaluators:\n", " for evaluator in evaluators:\n",
" eval_feedback.append(client.aevaluate_run(run, evaluator))\n", " eval_feedback.append(client.aevaluate_run(run, evaluator))\n",
@ -592,26 +592,12 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 29, "execution_count": null,
"id": "8696f167-dc75-4ef8-8bb3-ac1ce8324f30", "id": "8696f167-dc75-4ef8-8bb3-ac1ce8324f30",
"metadata": { "metadata": {
"tags": [] "tags": []
}, },
"outputs": [ "outputs": [],
{
"data": {
"text/html": [
"<a href=\"https://dev.langchain.plus\", target=\"_blank\" rel=\"noopener\">LangChain+ Client</a>"
],
"text/plain": [
"LangChainPlusClient (API URL: https://dev.api.langchain.plus)"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [ "source": [
"client" "client"
] ]

View File

@ -201,4 +201,4 @@ async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
] ]
for uuid_ in uuids for uuid_ in uuids
} }
assert results == expected assert results["results"] == expected