forked from Archives/langchain
Return session name in runner response (#6112)
Makes it easier to then run evals w/o thinking about specifying a session
This commit is contained in:
parent
e74733ab9e
commit
b3b155d488
@ -422,14 +422,14 @@ async def arun_on_dataset(
|
|||||||
client will be created using the credentials in the environment.
|
client will be created using the credentials in the environment.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A dictionary mapping example ids to the model outputs.
|
A dictionary containing the run's session name and the resulting model outputs.
|
||||||
"""
|
"""
|
||||||
client_ = client or LangChainPlusClient()
|
client_ = client or LangChainPlusClient()
|
||||||
session_name = _get_session_name(session_name, llm_or_chain_factory, dataset_name)
|
session_name = _get_session_name(session_name, llm_or_chain_factory, dataset_name)
|
||||||
dataset = client_.read_dataset(dataset_name=dataset_name)
|
dataset = client_.read_dataset(dataset_name=dataset_name)
|
||||||
examples = client_.list_examples(dataset_id=str(dataset.id))
|
examples = client_.list_examples(dataset_id=str(dataset.id))
|
||||||
|
|
||||||
return await arun_on_examples(
|
results = await arun_on_examples(
|
||||||
examples,
|
examples,
|
||||||
llm_or_chain_factory,
|
llm_or_chain_factory,
|
||||||
concurrency_level=concurrency_level,
|
concurrency_level=concurrency_level,
|
||||||
@ -437,6 +437,10 @@ async def arun_on_dataset(
|
|||||||
session_name=session_name,
|
session_name=session_name,
|
||||||
verbose=verbose,
|
verbose=verbose,
|
||||||
)
|
)
|
||||||
|
return {
|
||||||
|
"session_name": session_name,
|
||||||
|
"results": results,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def run_on_dataset(
|
def run_on_dataset(
|
||||||
@ -466,16 +470,20 @@ def run_on_dataset(
|
|||||||
will be created using the credentials in the environment.
|
will be created using the credentials in the environment.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A dictionary mapping example ids to the model outputs.
|
A dictionary containing the run's session name and the resulting model outputs.
|
||||||
"""
|
"""
|
||||||
client_ = client or LangChainPlusClient()
|
client_ = client or LangChainPlusClient()
|
||||||
session_name = _get_session_name(session_name, llm_or_chain_factory, dataset_name)
|
session_name = _get_session_name(session_name, llm_or_chain_factory, dataset_name)
|
||||||
dataset = client_.read_dataset(dataset_name=dataset_name)
|
dataset = client_.read_dataset(dataset_name=dataset_name)
|
||||||
examples = client_.list_examples(dataset_id=str(dataset.id))
|
examples = client_.list_examples(dataset_id=str(dataset.id))
|
||||||
return run_on_examples(
|
results = run_on_examples(
|
||||||
examples,
|
examples,
|
||||||
llm_or_chain_factory,
|
llm_or_chain_factory,
|
||||||
num_repetitions=num_repetitions,
|
num_repetitions=num_repetitions,
|
||||||
session_name=session_name,
|
session_name=session_name,
|
||||||
verbose=verbose,
|
verbose=verbose,
|
||||||
)
|
)
|
||||||
|
return {
|
||||||
|
"session_name": session_name,
|
||||||
|
"results": results,
|
||||||
|
}
|
||||||
|
@ -212,6 +212,8 @@
|
|||||||
" error=False, # Only runs that succeed\n",
|
" error=False, # Only runs that succeed\n",
|
||||||
")\n",
|
")\n",
|
||||||
"for run in runs:\n",
|
"for run in runs:\n",
|
||||||
|
" if run.outputs is None:\n",
|
||||||
|
" continue\n",
|
||||||
" try:\n",
|
" try:\n",
|
||||||
" client.create_example(\n",
|
" client.create_example(\n",
|
||||||
" inputs=run.inputs, outputs=run.outputs, dataset_id=dataset.id\n",
|
" inputs=run.inputs, outputs=run.outputs, dataset_id=dataset.id\n",
|
||||||
@ -388,7 +390,7 @@
|
|||||||
" client will be created using the credentials in the environment.\n",
|
" client will be created using the credentials in the environment.\n",
|
||||||
"\n",
|
"\n",
|
||||||
"Returns:\n",
|
"Returns:\n",
|
||||||
" A dictionary mapping example ids to the model outputs.\n",
|
" A dictionary containing the run's session name and the resulting model outputs.\n",
|
||||||
"\u001b[0;31mFile:\u001b[0m ~/code/lc/lckg/langchain/client/runner_utils.py\n",
|
"\u001b[0;31mFile:\u001b[0m ~/code/lc/lckg/langchain/client/runner_utils.py\n",
|
||||||
"\u001b[0;31mType:\u001b[0m function"
|
"\u001b[0;31mType:\u001b[0m function"
|
||||||
]
|
]
|
||||||
@ -438,16 +440,14 @@
|
|||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"Processed examples: 3\r"
|
"Processed examples: 4\r"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "stderr",
|
"name": "stderr",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"Chain failed for example 59fb1b4d-d935-4e43-b2a7-bc33fde841bb. Error: LLMMathChain._evaluate(\"\n",
|
"Chain failed for example c855f923-4165-4fe0-a909-360749f3f764. Error: Could not parse LLM output: `The final answer is that there were no more points scored in the 2023 Super Bowl than in the 2022 Super Bowl.`\n"
|
||||||
"round(0.2791714614499425, 2)\n",
|
|
||||||
"\") raised error: 'VariableNode' object is not callable. Please try again with a valid numerical expression\n"
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -459,13 +459,11 @@
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"evaluation_session_name = \"Search + Calculator Agent Evaluation\"\n",
|
|
||||||
"chain_results = await arun_on_dataset(\n",
|
"chain_results = await arun_on_dataset(\n",
|
||||||
" dataset_name=dataset_name,\n",
|
" dataset_name=dataset_name,\n",
|
||||||
" llm_or_chain_factory=chain_factory,\n",
|
" llm_or_chain_factory=chain_factory,\n",
|
||||||
" concurrency_level=5, # Optional, sets the number of examples to run at a time\n",
|
" concurrency_level=5, # Optional, sets the number of examples to run at a time\n",
|
||||||
" verbose=True,\n",
|
" verbose=True,\n",
|
||||||
" session_name=evaluation_session_name, # Optional, a unique session name will be generated if not provided\n",
|
|
||||||
" client=client,\n",
|
" client=client,\n",
|
||||||
")\n",
|
")\n",
|
||||||
"\n",
|
"\n",
|
||||||
@ -558,7 +556,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 27,
|
"execution_count": 16,
|
||||||
"id": "4c94a738-dcd3-442e-b8e7-dd36459f56e3",
|
"id": "4c94a738-dcd3-442e-b8e7-dd36459f56e3",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"tags": []
|
"tags": []
|
||||||
@ -567,7 +565,7 @@
|
|||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"application/vnd.jupyter.widget-view+json": {
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
"model_id": "a185493c1af74cbaa0f9b10f32cf81c6",
|
"model_id": "9989f6507cd04ea7a09ea3c5723dc984",
|
||||||
"version_major": 2,
|
"version_major": 2,
|
||||||
"version_minor": 0
|
"version_minor": 0
|
||||||
},
|
},
|
||||||
@ -582,8 +580,10 @@
|
|||||||
"source": [
|
"source": [
|
||||||
"from tqdm.notebook import tqdm\n",
|
"from tqdm.notebook import tqdm\n",
|
||||||
"feedbacks = []\n",
|
"feedbacks = []\n",
|
||||||
"runs = client.list_runs(session_name=evaluation_session_name, execution_order=1, error=False)\n",
|
"runs = client.list_runs(session_name=chain_results[\"session_name\"], execution_order=1, error=False)\n",
|
||||||
"for run in tqdm(runs):\n",
|
"for run in tqdm(runs):\n",
|
||||||
|
" if run.outputs is None:\n",
|
||||||
|
" continue\n",
|
||||||
" eval_feedback = []\n",
|
" eval_feedback = []\n",
|
||||||
" for evaluator in evaluators:\n",
|
" for evaluator in evaluators:\n",
|
||||||
" eval_feedback.append(client.aevaluate_run(run, evaluator))\n",
|
" eval_feedback.append(client.aevaluate_run(run, evaluator))\n",
|
||||||
@ -592,26 +592,12 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 29,
|
"execution_count": null,
|
||||||
"id": "8696f167-dc75-4ef8-8bb3-ac1ce8324f30",
|
"id": "8696f167-dc75-4ef8-8bb3-ac1ce8324f30",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"tags": []
|
"tags": []
|
||||||
},
|
},
|
||||||
"outputs": [
|
"outputs": [],
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"text/html": [
|
|
||||||
"<a href=\"https://dev.langchain.plus\", target=\"_blank\" rel=\"noopener\">LangChain+ Client</a>"
|
|
||||||
],
|
|
||||||
"text/plain": [
|
|
||||||
"LangChainPlusClient (API URL: https://dev.api.langchain.plus)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"execution_count": 29,
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "execute_result"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
"source": [
|
||||||
"client"
|
"client"
|
||||||
]
|
]
|
||||||
|
@ -201,4 +201,4 @@ async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
|
|||||||
]
|
]
|
||||||
for uuid_ in uuids
|
for uuid_ in uuids
|
||||||
}
|
}
|
||||||
assert results == expected
|
assert results["results"] == expected
|
||||||
|
Loading…
Reference in New Issue
Block a user