From e042e5df35288fb2809d9947c9e1ca8c3266cd12 Mon Sep 17 00:00:00 2001 From: Jon Watte Date: Mon, 4 Dec 2023 19:44:50 -0800 Subject: [PATCH] fix: call _on_llm_error() (#13581) Description: There's a copy-paste typo where on_llm_error() calls _on_chain_error() instead of _on_llm_error(). Issue: #13580 Dependencies: None Tag maintainer: @hwchase17 Twitter handle: @jwatte "Run `make format`, `make lint` and `make test` to check this locally." The test scripts don't work in a plain Ubuntu LTS 20.04 system. It looks like the dev container pulling is stuck. Or maybe the internet is just ornery today. --------- Co-authored-by: jwatte Co-authored-by: Harrison Chase --- libs/core/langchain_core/tracers/base.py | 2 +- .../callbacks/tracers/test_base_tracer.py | 36 +++++++++++++++++++ 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/libs/core/langchain_core/tracers/base.py b/libs/core/langchain_core/tracers/base.py index db0301b2a1..ddf9e92e2a 100644 --- a/libs/core/langchain_core/tracers/base.py +++ b/libs/core/langchain_core/tracers/base.py @@ -224,7 +224,7 @@ class BaseTracer(BaseCallbackHandler, ABC): llm_run.end_time = datetime.utcnow() llm_run.events.append({"name": "error", "time": llm_run.end_time}) self._end_trace(llm_run) - self._on_chain_error(llm_run) + self._on_llm_error(llm_run) return llm_run def on_chain_start( diff --git a/libs/langchain/tests/unit_tests/callbacks/tracers/test_base_tracer.py b/libs/langchain/tests/unit_tests/callbacks/tracers/test_base_tracer.py index f658abe260..94be5295c2 100644 --- a/libs/langchain/tests/unit_tests/callbacks/tracers/test_base_tracer.py +++ b/libs/langchain/tests/unit_tests/callbacks/tracers/test_base_tracer.py @@ -332,6 +332,42 @@ def test_tracer_llm_run_on_error() -> None: assert tracer.runs == [compare_run] +@freeze_time("2023-01-01") +def test_tracer_llm_run_on_error_callback() -> None: + """Test tracer on an LLM run with an error and a callback.""" + exception = Exception("test") + uuid = uuid4() + + compare_run = Run( + id=str(uuid), + start_time=datetime.utcnow(), + end_time=datetime.utcnow(), + events=[ + {"name": "start", "time": datetime.utcnow()}, + {"name": "error", "time": datetime.utcnow()}, + ], + extra={}, + execution_order=1, + child_execution_order=1, + serialized=SERIALIZED, + inputs=dict(prompts=[]), + outputs=None, + error=repr(exception), + run_type="llm", + ) + + class FakeTracerWithLlmErrorCallback(FakeTracer): + error_run = None + + def _on_llm_error(self, run: Run) -> None: + self.error_run = run + + tracer = FakeTracerWithLlmErrorCallback() + tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid) + tracer.on_llm_error(exception, run_id=uuid) + assert tracer.error_run == compare_run + + @freeze_time("2023-01-01") def test_tracer_chain_run_on_error() -> None: """Test tracer on a Chain run with an error."""