From f2150285a495fc530a7707218ea4980c17a170e5 Mon Sep 17 00:00:00 2001 From: Zander Chase <130414180+vowelparrot@users.noreply.github.com> Date: Tue, 9 May 2023 12:21:53 -0700 Subject: [PATCH] Fix nested runs example ID (#4413) #### Only reference example ID on the parent run Previously, I was assigning the example ID to every child run. Adds a test. --- langchain/callbacks/tracers/langchain.py | 2 +- .../callbacks/tracers/test_tracer.py | 34 +++++++++++++++++++ 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/langchain/callbacks/tracers/langchain.py b/langchain/callbacks/tracers/langchain.py index 41c6446f..1d581d35 100644 --- a/langchain/callbacks/tracers/langchain.py +++ b/langchain/callbacks/tracers/langchain.py @@ -236,13 +236,13 @@ class LangChainTracerV2(LangChainTracer): outputs=outputs, session_id=session.id, run_type=run_type, - reference_example_id=self.example_id, child_runs=[self._convert_run(child) for child in child_runs], ) def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None: """Persist a run.""" run_create = self._convert_run(run) + run_create.reference_example_id = self.example_id try: response = requests.post( f"{self._endpoint}/runs", diff --git a/tests/unit_tests/callbacks/tracers/test_tracer.py b/tests/unit_tests/callbacks/tracers/test_tracer.py index 0791885b..488d1f70 100644 --- a/tests/unit_tests/callbacks/tracers/test_tracer.py +++ b/tests/unit_tests/callbacks/tracers/test_tracer.py @@ -1,6 +1,7 @@ """Test Tracer classes.""" from __future__ import annotations +import json from datetime import datetime from typing import List, Tuple, Union from unittest.mock import Mock, patch @@ -605,6 +606,39 @@ def test_persist_run( assert get.call_count == 0 +def test_persist_run_with_example_id( + lang_chain_tracer_v2: LangChainTracerV2, + sample_tracer_session_v2: TracerSessionV2, + sample_runs: Tuple[LLMRun, ChainRun, ToolRun], +) -> None: + """Test the example ID is assigned only to the parent run and not the children.""" + example_id = uuid4() + llm_run, chain_run, tool_run = sample_runs + chain_run.child_tool_runs = [tool_run] + tool_run.child_llm_runs = [llm_run] + with patch("langchain.callbacks.tracers.langchain.requests.post") as post, patch( + "langchain.callbacks.tracers.langchain.requests.get" + ) as get: + post.return_value.raise_for_status.return_value = None + lang_chain_tracer_v2.session = sample_tracer_session_v2 + lang_chain_tracer_v2.example_id = example_id + lang_chain_tracer_v2._persist_run(chain_run) + + assert post.call_count == 1 + assert get.call_count == 0 + posted_data = json.loads(post.call_args[1]["data"]) + assert posted_data["id"] == chain_run.uuid + assert posted_data["reference_example_id"] == str(example_id) + + def assert_child_run_no_example_id(run: dict) -> None: + assert not run.get("reference_example_id") + for child_run in run.get("child_runs", []): + assert_child_run_no_example_id(child_run) + + for child_run in posted_data["child_runs"]: + assert_child_run_no_example_id(child_run) + + def test_get_session_create(lang_chain_tracer_v2: LangChainTracerV2) -> None: """Test creating the 'SessionCreate' object.""" lang_chain_tracer_v2.tenant_id = str(_TENANT_ID)