Make Ref Example Threadsafe (#7383)

Have noticed transient ref example misalignment. I believe this is
caused by the logic of assigning an example within the thread executor
rather than before.
This commit is contained in:
William FH 2023-07-07 21:50:42 -07:00 committed by GitHub
parent 4789c99bc2
commit 612a74eb7e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 61 additions and 2 deletions

View File

@ -109,8 +109,6 @@ class LangChainTracer(BaseTracer):
def _persist_run_single(self, run: Run) -> None: def _persist_run_single(self, run: Run) -> None:
"""Persist a run.""" """Persist a run."""
if run.parent_run_id is None:
run.reference_example_id = self.example_id
run_dict = run.dict(exclude={"child_runs"}) run_dict = run.dict(exclude={"child_runs"})
run_dict["tags"] = self._get_tags(run) run_dict["tags"] = self._get_tags(run)
extra = run_dict.get("extra", {}) extra = run_dict.get("extra", {})
@ -136,12 +134,16 @@ class LangChainTracer(BaseTracer):
def _on_llm_start(self, run: Run) -> None: def _on_llm_start(self, run: Run) -> None:
"""Persist an LLM run.""" """Persist an LLM run."""
if run.parent_run_id is None:
run.reference_example_id = self.example_id
self._futures.add( self._futures.add(
self.executor.submit(self._persist_run_single, run.copy(deep=True)) self.executor.submit(self._persist_run_single, run.copy(deep=True))
) )
def _on_chat_model_start(self, run: Run) -> None: def _on_chat_model_start(self, run: Run) -> None:
"""Persist an LLM run.""" """Persist an LLM run."""
if run.parent_run_id is None:
run.reference_example_id = self.example_id
self._futures.add( self._futures.add(
self.executor.submit(self._persist_run_single, run.copy(deep=True)) self.executor.submit(self._persist_run_single, run.copy(deep=True))
) )
@ -160,6 +162,8 @@ class LangChainTracer(BaseTracer):
def _on_chain_start(self, run: Run) -> None: def _on_chain_start(self, run: Run) -> None:
"""Process the Chain Run upon start.""" """Process the Chain Run upon start."""
if run.parent_run_id is None:
run.reference_example_id = self.example_id
self._futures.add( self._futures.add(
self.executor.submit(self._persist_run_single, run.copy(deep=True)) self.executor.submit(self._persist_run_single, run.copy(deep=True))
) )
@ -178,6 +182,8 @@ class LangChainTracer(BaseTracer):
def _on_tool_start(self, run: Run) -> None: def _on_tool_start(self, run: Run) -> None:
"""Process the Tool Run upon start.""" """Process the Tool Run upon start."""
if run.parent_run_id is None:
run.reference_example_id = self.example_id
self._futures.add( self._futures.add(
self.executor.submit(self._persist_run_single, run.copy(deep=True)) self.executor.submit(self._persist_run_single, run.copy(deep=True))
) )
@ -196,6 +202,8 @@ class LangChainTracer(BaseTracer):
def _on_retriever_start(self, run: Run) -> None: def _on_retriever_start(self, run: Run) -> None:
"""Process the Retriever Run upon start.""" """Process the Retriever Run upon start."""
if run.parent_run_id is None:
run.reference_example_id = self.example_id
self._futures.add( self._futures.add(
self.executor.submit(self._persist_run_single, run.copy(deep=True)) self.executor.submit(self._persist_run_single, run.copy(deep=True))
) )

View File

@ -0,0 +1,51 @@
import time
import unittest.mock
from typing import Any
from uuid import UUID
from langchainplus_sdk import LangChainPlusClient
from langchain.callbacks.tracers.langchain import LangChainTracer
from langchain.callbacks.tracers.schemas import Run
from langchain.schema.output import LLMResult
def test_example_id_assignment_threadsafe() -> None:
"""Test that example assigned at callback start/end is honored."""
example_ids = {}
def mock_create_run(self: Any, **kwargs: Any) -> Any:
example_ids[kwargs.get("id")] = kwargs.get("reference_example_id")
return unittest.mock.MagicMock()
with unittest.mock.patch.object(
LangChainPlusClient, "create_run", new=mock_create_run
):
client = LangChainPlusClient()
tracer = LangChainTracer(client=client)
old_persist_run_single = tracer._persist_run_single
def new_persist_run_single(run: Run) -> None:
time.sleep(0.01)
old_persist_run_single(run)
with unittest.mock.patch.object(
tracer, "_persist_run_single", new=new_persist_run_single
):
run_id_1 = UUID("9d878ab3-e5ca-4218-aef6-44cbdc90160a")
run_id_2 = UUID("f1f9fa53-8b2f-4742-bdbc-38215f7bd1e1")
example_id_1 = UUID("57e42c57-8c79-4d9f-8765-bf6cd3a98055")
tracer.example_id = example_id_1
tracer.on_llm_start({"name": "example_1"}, ["foo"], run_id=run_id_1)
tracer.on_llm_end(LLMResult(generations=[], llm_output={}), run_id=run_id_1)
example_id_2 = UUID("4f31216e-7c26-4027-a5fd-0bbf9ace17dc")
tracer.example_id = example_id_2
tracer.on_llm_start({"name": "example_2"}, ["foo"], run_id=run_id_2)
tracer.on_llm_end(LLMResult(generations=[], llm_output={}), run_id=run_id_2)
tracer.example_id = None
expected_example_ids = {
run_id_1: example_id_1,
run_id_2: example_id_2,
}
tracer.wait_for_futures()
assert example_ids == expected_example_ids