From 907c57e3244e402d28914baf145a5c6fb4213f66 Mon Sep 17 00:00:00 2001 From: William FH <13333726+hinthornw@users.noreply.github.com> Date: Mon, 28 Aug 2023 15:30:41 -0700 Subject: [PATCH] Add collect_runs callback (#9885) --- .../langchain/langchain/callbacks/__init__.py | 2 ++ libs/langchain/langchain/callbacks/manager.py | 27 +++++++++++++++++++ .../callbacks/test_run_collector.py | 16 +++++++++++ 3 files changed, 45 insertions(+) create mode 100644 libs/langchain/tests/unit_tests/callbacks/test_run_collector.py diff --git a/libs/langchain/langchain/callbacks/__init__.py b/libs/langchain/langchain/callbacks/__init__.py index 8398741be3..4b9f93bf84 100644 --- a/libs/langchain/langchain/callbacks/__init__.py +++ b/libs/langchain/langchain/callbacks/__init__.py @@ -20,6 +20,7 @@ from langchain.callbacks.human import HumanApprovalCallbackHandler from langchain.callbacks.infino_callback import InfinoCallbackHandler from langchain.callbacks.labelstudio_callback import LabelStudioCallbackHandler from langchain.callbacks.manager import ( + collect_runs, get_openai_callback, tracing_enabled, tracing_v2_enabled, @@ -66,6 +67,7 @@ __all__ = [ "get_openai_callback", "tracing_enabled", "tracing_v2_enabled", + "collect_runs", "wandb_tracing_enabled", "FlyteCallbackHandler", "SageMakerCallbackHandler", diff --git a/libs/langchain/langchain/callbacks/manager.py b/libs/langchain/langchain/callbacks/manager.py index 3f22832de3..2f7a7fad47 100644 --- a/libs/langchain/langchain/callbacks/manager.py +++ b/libs/langchain/langchain/callbacks/manager.py @@ -38,6 +38,7 @@ from langchain.callbacks.base import ( ) from langchain.callbacks.openai_info import OpenAICallbackHandler from langchain.callbacks.stdout import StdOutCallbackHandler +from langchain.callbacks.tracers import run_collector from langchain.callbacks.tracers.langchain import LangChainTracer from langchain.callbacks.tracers.langchain_v1 import LangChainTracerV1, TracerSessionV1 from langchain.callbacks.tracers.stdout import ConsoleCallbackHandler @@ -75,6 +76,11 @@ tracing_v2_callback_var: ContextVar[ ] = ContextVar( # noqa: E501 "tracing_callback_v2", default=None ) +run_collector_var: ContextVar[ + Optional[run_collector.RunCollectorCallbackHandler] +] = ContextVar( # noqa: E501 + "run_collector", default=None +) def _get_debug() -> bool: @@ -184,6 +190,24 @@ def tracing_v2_enabled( tracing_v2_callback_var.set(None) +@contextmanager +def collect_runs() -> Generator[run_collector.RunCollectorCallbackHandler, None, None]: + """Collect all run traces in context. + + Returns: + run_collector.RunCollectorCallbackHandler: The run collector callback handler. + + Example: + >>> with collect_runs() as runs_cb: + chain.invoke("foo") + run_id = runs_cb.traced_runs[0].id + """ + cb = run_collector.RunCollectorCallbackHandler() + run_collector_var.set(cb) + yield cb + run_collector_var.set(None) + + @contextmanager def trace_as_chain_group( group_name: str, @@ -1712,6 +1736,7 @@ def _configure( tracer_project = os.environ.get( "LANGCHAIN_PROJECT", os.environ.get("LANGCHAIN_SESSION", "default") ) + run_collector_ = run_collector_var.get() debug = _get_debug() if ( verbose @@ -1774,4 +1799,6 @@ def _configure( for handler in callback_manager.handlers ): callback_manager.add_handler(open_ai, True) + if run_collector_ is not None: + callback_manager.add_handler(run_collector_, False) return callback_manager diff --git a/libs/langchain/tests/unit_tests/callbacks/test_run_collector.py b/libs/langchain/tests/unit_tests/callbacks/test_run_collector.py new file mode 100644 index 0000000000..9fd031b7a8 --- /dev/null +++ b/libs/langchain/tests/unit_tests/callbacks/test_run_collector.py @@ -0,0 +1,16 @@ +"""Test the run collector.""" + +import uuid + +from langchain.callbacks import collect_runs +from tests.unit_tests.llms.fake_llm import FakeLLM + + +def test_collect_runs() -> None: + llm = FakeLLM(queries={"hi": "hello"}, sequential_responses=True) + with collect_runs() as cb: + llm.predict("hi") + assert cb.traced_runs + assert len(cb.traced_runs) == 1 + assert isinstance(cb.traced_runs[0].id, uuid.UUID) + assert cb.traced_runs[0].inputs == {"prompts": ["hi"]}