Return session name in runner response (#6112)

Makes it easier to run evals afterward without having to think about specifying a session name.
Zander Chase 2023-06-13 16:59:43 -07:00 committed by GitHub
parent e74733ab9e
commit b3b155d488
3 changed files with 25 additions and 31 deletions
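
As a usage sketch (not part of the diff): with this change, a caller can feed the auto-generated session name straight back into the client instead of tracking it by hand. The names `dataset_name`, `chain_factory`, and `client` below are borrowed from the notebook in this commit; the exact import locations are assumptions.

    from langchain.client import LangChainPlusClient, arun_on_dataset  # import paths assumed

    client = LangChainPlusClient()

    # The runner now returns a dict rather than the bare example-id -> output mapping.
    chain_results = await arun_on_dataset(
        dataset_name=dataset_name,
        llm_or_chain_factory=chain_factory,
        client=client,
    )

    # The generated session name comes back in the response, so listing the runs
    # for evaluation no longer requires choosing a session name up front.
    runs = client.list_runs(
        session_name=chain_results["session_name"],
        execution_order=1,
        error=False,
    )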

langchain/client/runner_utils.py

@@ -422,14 +422,14 @@ async def arun_on_dataset(
             client will be created using the credentials in the environment.
     Returns:
-        A dictionary mapping example ids to the model outputs.
+        A dictionary containing the run's session name and the resulting model outputs.
     """
     client_ = client or LangChainPlusClient()
     session_name = _get_session_name(session_name, llm_or_chain_factory, dataset_name)
     dataset = client_.read_dataset(dataset_name=dataset_name)
     examples = client_.list_examples(dataset_id=str(dataset.id))
-    return await arun_on_examples(
+    results = await arun_on_examples(
         examples,
         llm_or_chain_factory,
         concurrency_level=concurrency_level,
@@ -437,6 +437,10 @@ async def arun_on_dataset(
         session_name=session_name,
         verbose=verbose,
     )
+    return {
+        "session_name": session_name,
+        "results": results,
+    }
 def run_on_dataset(
@@ -466,16 +470,20 @@ def run_on_dataset(
             will be created using the credentials in the environment.
     Returns:
-        A dictionary mapping example ids to the model outputs.
+        A dictionary containing the run's session name and the resulting model outputs.
     """
     client_ = client or LangChainPlusClient()
     session_name = _get_session_name(session_name, llm_or_chain_factory, dataset_name)
     dataset = client_.read_dataset(dataset_name=dataset_name)
     examples = client_.list_examples(dataset_id=str(dataset.id))
-    return run_on_examples(
+    results = run_on_examples(
         examples,
         llm_or_chain_factory,
         num_repetitions=num_repetitions,
         session_name=session_name,
         verbose=verbose,
     )
+    return {
+        "session_name": session_name,
+        "results": results,
+    }
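
For context, `_get_session_name` is the helper both runners call to resolve the name they now return; per the notebook below, a unique session name is generated when none is provided. A minimal sketch of that behavior, assuming a timestamp-based scheme (hypothetical, not the code from this commit):

    from datetime import datetime
    from typing import Optional

    def _get_session_name(
        session_name: Optional[str],
        llm_or_chain_factory,
        dataset_name: str,
    ) -> str:
        # If the caller supplied a name, use it unchanged.
        if session_name is not None:
            return session_name
        # Otherwise derive a unique name from the dataset and a timestamp
        # (the exact format here is a guess, not the library's actual one).
        timestamp = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
        return f"{dataset_name}-{timestamp}"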

Client evaluation notebook (.ipynb)

@@ -212,6 +212,8 @@
     "    error=False,  # Only runs that succeed\n",
     ")\n",
     "for run in runs:\n",
+    "    if run.outputs is None:\n",
+    "        continue\n",
     "    try:\n",
     "        client.create_example(\n",
     "            inputs=run.inputs, outputs=run.outputs, dataset_id=dataset.id\n",
@@ -388,7 +390,7 @@
     "    client will be created using the credentials in the environment.\n",
     "\n",
     "Returns:\n",
-    "    A dictionary mapping example ids to the model outputs.\n",
+    "    A dictionary containing the run's session name and the resulting model outputs.\n",
     "\u001b[0;31mFile:\u001b[0m ~/code/lc/lckg/langchain/client/runner_utils.py\n",
     "\u001b[0;31mType:\u001b[0m function"
    ]
@@ -438,16 +440,14 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "Processed examples: 3\r"
+      "Processed examples: 4\r"
      ]
     },
     {
      "name": "stderr",
      "output_type": "stream",
      "text": [
-      "Chain failed for example 59fb1b4d-d935-4e43-b2a7-bc33fde841bb. Error: LLMMathChain._evaluate(\"\n",
-      "round(0.2791714614499425, 2)\n",
-      "\") raised error: 'VariableNode' object is not callable. Please try again with a valid numerical expression\n"
+      "Chain failed for example c855f923-4165-4fe0-a909-360749f3f764. Error: Could not parse LLM output: `The final answer is that there were no more points scored in the 2023 Super Bowl than in the 2022 Super Bowl.`\n"
      ]
     },
     {
@@ -459,13 +459,11 @@
     }
    ],
    "source": [
-    "evaluation_session_name = \"Search + Calculator Agent Evaluation\"\n",
     "chain_results = await arun_on_dataset(\n",
     "    dataset_name=dataset_name,\n",
     "    llm_or_chain_factory=chain_factory,\n",
     "    concurrency_level=5,  # Optional, sets the number of examples to run at a time\n",
     "    verbose=True,\n",
-    "    session_name=evaluation_session_name,  # Optional, a unique session name will be generated if not provided\n",
     "    client=client,\n",
     ")\n",
     "\n",
@@ -558,7 +556,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 27,
+   "execution_count": 16,
    "id": "4c94a738-dcd3-442e-b8e7-dd36459f56e3",
    "metadata": {
     "tags": []
@@ -567,7 +565,7 @@
   {
    "data": {
     "application/vnd.jupyter.widget-view+json": {
-     "model_id": "a185493c1af74cbaa0f9b10f32cf81c6",
+     "model_id": "9989f6507cd04ea7a09ea3c5723dc984",
      "version_major": 2,
      "version_minor": 0
     },
@@ -582,8 +580,10 @@
    "source": [
     "from tqdm.notebook import tqdm\n",
     "feedbacks = []\n",
-    "runs = client.list_runs(session_name=evaluation_session_name, execution_order=1, error=False)\n",
+    "runs = client.list_runs(session_name=chain_results[\"session_name\"], execution_order=1, error=False)\n",
     "for run in tqdm(runs):\n",
+    "    if run.outputs is None:\n",
+    "        continue\n",
     "    eval_feedback = []\n",
     "    for evaluator in evaluators:\n",
     "        eval_feedback.append(client.aevaluate_run(run, evaluator))\n",
@@ -592,26 +592,12 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 29,
+   "execution_count": null,
    "id": "8696f167-dc75-4ef8-8bb3-ac1ce8324f30",
    "metadata": {
     "tags": []
    },
-   "outputs": [
-    {
-     "data": {
-      "text/html": [
-       "<a href=\"https://dev.langchain.plus\", target=\"_blank\" rel=\"noopener\">LangChain+ Client</a>"
-      ],
-      "text/plain": [
-       "LangChainPlusClient (API URL: https://dev.api.langchain.plus)"
-      ]
-     },
-     "execution_count": 29,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
+   "outputs": [],
    "source": [
     "client"
    ]

Unit tests for the dataset runners (test_arun_on_dataset)

@@ -201,4 +201,4 @@ async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
         ]
         for uuid_ in uuids
     }
-    assert results == expected
+    assert results["results"] == expected
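
Since the runners now wrap their output, a natural follow-up assertion (hypothetical, not in this commit) would also check the new key:

    # Hypothetical extra check: the wrapper dict carries the session name too.
    assert "session_name" in results
    assert isinstance(results["session_name"], str)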