Fix Pickle Error (#12141)

If non-pickleable objects (like locks) get passed to the tracing
callback, they'll fail in the deepcopy. Fallback to a shallow copy in
these instances .
pull/12161/head
William FH 12 months ago committed by GitHub
parent 95a1b598fe
commit 4f23aa677a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -411,11 +411,12 @@ def _handle_event(
handler_name = handler.__class__.__name__
logger.warning(
f"NotImplementedError in {handler_name}.{event_name}"
f" callback: {e}"
f" callback: {repr(e)}"
)
except Exception as e:
logger.warning(
f"Error in {handler.__class__.__name__}.{event_name} callback: {e}"
f"Error in {handler.__class__.__name__}.{event_name} callback:"
f" {repr(e)}"
)
if handler.raise_error:
raise e
@ -496,11 +497,12 @@ async def _ahandle_event_for_handler(
else:
logger.warning(
f"NotImplementedError in {handler.__class__.__name__}.{event_name}"
f" callback: {e}"
f" callback: {repr(e)}"
)
except Exception as e:
logger.warning(
f"Error in {handler.__class__.__name__}.{event_name} callback: {e}"
f"Error in {handler.__class__.__name__}.{event_name} callback:"
f" {repr(e)}"
)
if handler.raise_error:
raise e

@ -64,6 +64,16 @@ def _get_executor() -> ThreadPoolExecutor:
return _EXECUTOR
def _copy(run: Run) -> Run:
"""Copy a run."""
try:
return run.copy(deep=True)
except TypeError:
# Fallback in case the object contains a lock or other
# non-pickleable object
return run.copy()
class LangChainTracer(BaseTracer):
"""An implementation of the SharedTracer that POSTS to the langchain endpoint."""
@ -192,63 +202,63 @@ class LangChainTracer(BaseTracer):
"""Persist an LLM run."""
if run.parent_run_id is None:
run.reference_example_id = self.example_id
self._submit(self._persist_run_single, run.copy(deep=True))
self._submit(self._persist_run_single, _copy(run))
def _on_chat_model_start(self, run: Run) -> None:
"""Persist an LLM run."""
if run.parent_run_id is None:
run.reference_example_id = self.example_id
self._submit(self._persist_run_single, run.copy(deep=True))
self._submit(self._persist_run_single, _copy(run))
def _on_llm_end(self, run: Run) -> None:
"""Process the LLM Run."""
self._submit(self._update_run_single, run.copy(deep=True))
self._submit(self._update_run_single, _copy(run))
def _on_llm_error(self, run: Run) -> None:
"""Process the LLM Run upon error."""
self._submit(self._update_run_single, run.copy(deep=True))
self._submit(self._update_run_single, _copy(run))
def _on_chain_start(self, run: Run) -> None:
"""Process the Chain Run upon start."""
if run.parent_run_id is None:
run.reference_example_id = self.example_id
self._submit(self._persist_run_single, run.copy(deep=True))
self._submit(self._persist_run_single, _copy(run))
def _on_chain_end(self, run: Run) -> None:
"""Process the Chain Run."""
self._submit(self._update_run_single, run.copy(deep=True))
self._submit(self._update_run_single, _copy(run))
def _on_chain_error(self, run: Run) -> None:
"""Process the Chain Run upon error."""
self._submit(self._update_run_single, run.copy(deep=True))
self._submit(self._update_run_single, _copy(run))
def _on_tool_start(self, run: Run) -> None:
"""Process the Tool Run upon start."""
if run.parent_run_id is None:
run.reference_example_id = self.example_id
self._submit(self._persist_run_single, run.copy(deep=True))
self._submit(self._persist_run_single, _copy(run))
def _on_tool_end(self, run: Run) -> None:
"""Process the Tool Run."""
self._submit(self._update_run_single, run.copy(deep=True))
self._submit(self._update_run_single, _copy(run))
def _on_tool_error(self, run: Run) -> None:
"""Process the Tool Run upon error."""
self._submit(self._update_run_single, run.copy(deep=True))
self._submit(self._update_run_single, _copy(run))
def _on_retriever_start(self, run: Run) -> None:
"""Process the Retriever Run upon start."""
if run.parent_run_id is None:
run.reference_example_id = self.example_id
self._submit(self._persist_run_single, run.copy(deep=True))
self._submit(self._persist_run_single, _copy(run))
def _on_retriever_end(self, run: Run) -> None:
"""Process the Retriever Run."""
self._submit(self._update_run_single, run.copy(deep=True))
self._submit(self._update_run_single, _copy(run))
def _on_retriever_error(self, run: Run) -> None:
"""Process the Retriever Run upon error."""
self._submit(self._update_run_single, run.copy(deep=True))
self._submit(self._update_run_single, _copy(run))
def wait_for_futures(self) -> None:
"""Wait for the given futures to complete."""

@ -1,3 +1,4 @@
import threading
import time
import unittest.mock
from typing import Any
@ -47,3 +48,17 @@ def test_example_id_assignment_threadsafe() -> None:
}
tracer.wait_for_futures()
assert example_ids == expected_example_ids
def test_log_lock() -> None:
"""Test that example assigned at callback start/end is honored."""
client = unittest.mock.MagicMock(spec=Client)
tracer = LangChainTracer(client=client)
with unittest.mock.patch.object(tracer, "_persist_run_single", new=lambda _: _):
run_id_1 = UUID("9d878ab3-e5ca-4218-aef6-44cbdc90160a")
lock = threading.Lock()
tracer.on_chain_start({"name": "example_1"}, {"input": lock}, run_id=run_id_1)
tracer.on_chain_end({}, run_id=run_id_1)
tracer.wait_for_futures()

Loading…
Cancel
Save