Fix nested runs example ID (#4413)

#### Only reference example ID on the parent run

Previously, I was assigning the example ID to every child run. 
Adds a test.
parallel_dir_loader
Zander Chase 1 year ago committed by GitHub
parent e4ca511ec8
commit f2150285a4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -236,13 +236,13 @@ class LangChainTracerV2(LangChainTracer):
outputs=outputs, outputs=outputs,
session_id=session.id, session_id=session.id,
run_type=run_type, run_type=run_type,
reference_example_id=self.example_id,
child_runs=[self._convert_run(child) for child in child_runs], child_runs=[self._convert_run(child) for child in child_runs],
) )
def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None: def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:
"""Persist a run.""" """Persist a run."""
run_create = self._convert_run(run) run_create = self._convert_run(run)
run_create.reference_example_id = self.example_id
try: try:
response = requests.post( response = requests.post(
f"{self._endpoint}/runs", f"{self._endpoint}/runs",

@ -1,6 +1,7 @@
"""Test Tracer classes.""" """Test Tracer classes."""
from __future__ import annotations from __future__ import annotations
import json
from datetime import datetime from datetime import datetime
from typing import List, Tuple, Union from typing import List, Tuple, Union
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
@ -605,6 +606,39 @@ def test_persist_run(
assert get.call_count == 0 assert get.call_count == 0
def test_persist_run_with_example_id(
lang_chain_tracer_v2: LangChainTracerV2,
sample_tracer_session_v2: TracerSessionV2,
sample_runs: Tuple[LLMRun, ChainRun, ToolRun],
) -> None:
"""Test the example ID is assigned only to the parent run and not the children."""
example_id = uuid4()
llm_run, chain_run, tool_run = sample_runs
chain_run.child_tool_runs = [tool_run]
tool_run.child_llm_runs = [llm_run]
with patch("langchain.callbacks.tracers.langchain.requests.post") as post, patch(
"langchain.callbacks.tracers.langchain.requests.get"
) as get:
post.return_value.raise_for_status.return_value = None
lang_chain_tracer_v2.session = sample_tracer_session_v2
lang_chain_tracer_v2.example_id = example_id
lang_chain_tracer_v2._persist_run(chain_run)
assert post.call_count == 1
assert get.call_count == 0
posted_data = json.loads(post.call_args[1]["data"])
assert posted_data["id"] == chain_run.uuid
assert posted_data["reference_example_id"] == str(example_id)
def assert_child_run_no_example_id(run: dict) -> None:
assert not run.get("reference_example_id")
for child_run in run.get("child_runs", []):
assert_child_run_no_example_id(child_run)
for child_run in posted_data["child_runs"]:
assert_child_run_no_example_id(child_run)
def test_get_session_create(lang_chain_tracer_v2: LangChainTracerV2) -> None: def test_get_session_create(lang_chain_tracer_v2: LangChainTracerV2) -> None:
"""Test creating the 'SessionCreate' object.""" """Test creating the 'SessionCreate' object."""
lang_chain_tracer_v2.tenant_id = str(_TENANT_ID) lang_chain_tracer_v2.tenant_id = str(_TENANT_ID)

Loading…
Cancel
Save