diff --git a/langchain/callbacks/tracers/langchain.py b/langchain/callbacks/tracers/langchain.py index 6e3d7da312..9437227bfb 100644 --- a/langchain/callbacks/tracers/langchain.py +++ b/langchain/callbacks/tracers/langchain.py @@ -109,8 +109,6 @@ class LangChainTracer(BaseTracer): def _persist_run_single(self, run: Run) -> None: """Persist a run.""" - if run.parent_run_id is None: - run.reference_example_id = self.example_id run_dict = run.dict(exclude={"child_runs"}) run_dict["tags"] = self._get_tags(run) extra = run_dict.get("extra", {}) @@ -136,12 +134,16 @@ class LangChainTracer(BaseTracer): def _on_llm_start(self, run: Run) -> None: """Persist an LLM run.""" + if run.parent_run_id is None: + run.reference_example_id = self.example_id self._futures.add( self.executor.submit(self._persist_run_single, run.copy(deep=True)) ) def _on_chat_model_start(self, run: Run) -> None: """Persist an LLM run.""" + if run.parent_run_id is None: + run.reference_example_id = self.example_id self._futures.add( self.executor.submit(self._persist_run_single, run.copy(deep=True)) ) @@ -160,6 +162,8 @@ class LangChainTracer(BaseTracer): def _on_chain_start(self, run: Run) -> None: """Process the Chain Run upon start.""" + if run.parent_run_id is None: + run.reference_example_id = self.example_id self._futures.add( self.executor.submit(self._persist_run_single, run.copy(deep=True)) ) @@ -178,6 +182,8 @@ class LangChainTracer(BaseTracer): def _on_tool_start(self, run: Run) -> None: """Process the Tool Run upon start.""" + if run.parent_run_id is None: + run.reference_example_id = self.example_id self._futures.add( self.executor.submit(self._persist_run_single, run.copy(deep=True)) ) @@ -196,6 +202,8 @@ class LangChainTracer(BaseTracer): def _on_retriever_start(self, run: Run) -> None: """Process the Retriever Run upon start.""" + if run.parent_run_id is None: + run.reference_example_id = self.example_id self._futures.add( self.executor.submit(self._persist_run_single, run.copy(deep=True)) ) diff --git a/tests/unit_tests/callbacks/tracers/test_langchain.py b/tests/unit_tests/callbacks/tracers/test_langchain.py new file mode 100644 index 0000000000..1adc857125 --- /dev/null +++ b/tests/unit_tests/callbacks/tracers/test_langchain.py @@ -0,0 +1,51 @@ +import time +import unittest.mock +from typing import Any +from uuid import UUID + +from langchainplus_sdk import LangChainPlusClient + +from langchain.callbacks.tracers.langchain import LangChainTracer +from langchain.callbacks.tracers.schemas import Run +from langchain.schema.output import LLMResult + + +def test_example_id_assignment_threadsafe() -> None: + """Test that example assigned at callback start/end is honored.""" + example_ids = {} + + def mock_create_run(self: Any, **kwargs: Any) -> Any: + example_ids[kwargs.get("id")] = kwargs.get("reference_example_id") + return unittest.mock.MagicMock() + + with unittest.mock.patch.object( + LangChainPlusClient, "create_run", new=mock_create_run + ): + client = LangChainPlusClient() + tracer = LangChainTracer(client=client) + old_persist_run_single = tracer._persist_run_single + + def new_persist_run_single(run: Run) -> None: + time.sleep(0.01) + old_persist_run_single(run) + + with unittest.mock.patch.object( + tracer, "_persist_run_single", new=new_persist_run_single + ): + run_id_1 = UUID("9d878ab3-e5ca-4218-aef6-44cbdc90160a") + run_id_2 = UUID("f1f9fa53-8b2f-4742-bdbc-38215f7bd1e1") + example_id_1 = UUID("57e42c57-8c79-4d9f-8765-bf6cd3a98055") + tracer.example_id = example_id_1 + tracer.on_llm_start({"name": "example_1"}, ["foo"], run_id=run_id_1) + tracer.on_llm_end(LLMResult(generations=[], llm_output={}), run_id=run_id_1) + example_id_2 = UUID("4f31216e-7c26-4027-a5fd-0bbf9ace17dc") + tracer.example_id = example_id_2 + tracer.on_llm_start({"name": "example_2"}, ["foo"], run_id=run_id_2) + tracer.on_llm_end(LLMResult(generations=[], llm_output={}), run_id=run_id_2) + tracer.example_id = None + expected_example_ids = { + run_id_1: example_id_1, + run_id_2: example_id_2, + } + tracer.wait_for_futures() + assert example_ids == expected_example_ids