Add collect_runs callback (#9885)

pull/9871/head
William FH 1 year ago committed by GitHub
parent 3103f07e03
commit 907c57e324
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -20,6 +20,7 @@ from langchain.callbacks.human import HumanApprovalCallbackHandler
from langchain.callbacks.infino_callback import InfinoCallbackHandler
from langchain.callbacks.labelstudio_callback import LabelStudioCallbackHandler
from langchain.callbacks.manager import (
collect_runs,
get_openai_callback,
tracing_enabled,
tracing_v2_enabled,
@ -66,6 +67,7 @@ __all__ = [
"get_openai_callback",
"tracing_enabled",
"tracing_v2_enabled",
"collect_runs",
"wandb_tracing_enabled",
"FlyteCallbackHandler",
"SageMakerCallbackHandler",

@ -38,6 +38,7 @@ from langchain.callbacks.base import (
)
from langchain.callbacks.openai_info import OpenAICallbackHandler
from langchain.callbacks.stdout import StdOutCallbackHandler
from langchain.callbacks.tracers import run_collector
from langchain.callbacks.tracers.langchain import LangChainTracer
from langchain.callbacks.tracers.langchain_v1 import LangChainTracerV1, TracerSessionV1
from langchain.callbacks.tracers.stdout import ConsoleCallbackHandler
@ -75,6 +76,11 @@ tracing_v2_callback_var: ContextVar[
] = ContextVar( # noqa: E501
"tracing_callback_v2", default=None
)
run_collector_var: ContextVar[
Optional[run_collector.RunCollectorCallbackHandler]
] = ContextVar( # noqa: E501
"run_collector", default=None
)
def _get_debug() -> bool:
@ -184,6 +190,24 @@ def tracing_v2_enabled(
tracing_v2_callback_var.set(None)
@contextmanager
def collect_runs() -> Generator[run_collector.RunCollectorCallbackHandler, None, None]:
"""Collect all run traces in context.
Returns:
run_collector.RunCollectorCallbackHandler: The run collector callback handler.
Example:
>>> with collect_runs() as runs_cb:
chain.invoke("foo")
run_id = runs_cb.traced_runs[0].id
"""
cb = run_collector.RunCollectorCallbackHandler()
run_collector_var.set(cb)
yield cb
run_collector_var.set(None)
@contextmanager
def trace_as_chain_group(
group_name: str,
@ -1712,6 +1736,7 @@ def _configure(
tracer_project = os.environ.get(
"LANGCHAIN_PROJECT", os.environ.get("LANGCHAIN_SESSION", "default")
)
run_collector_ = run_collector_var.get()
debug = _get_debug()
if (
verbose
@ -1774,4 +1799,6 @@ def _configure(
for handler in callback_manager.handlers
):
callback_manager.add_handler(open_ai, True)
if run_collector_ is not None:
callback_manager.add_handler(run_collector_, False)
return callback_manager

@ -0,0 +1,16 @@
"""Test the run collector."""
import uuid
from langchain.callbacks import collect_runs
from tests.unit_tests.llms.fake_llm import FakeLLM
def test_collect_runs() -> None:
llm = FakeLLM(queries={"hi": "hello"}, sequential_responses=True)
with collect_runs() as cb:
llm.predict("hi")
assert cb.traced_runs
assert len(cb.traced_runs) == 1
assert isinstance(cb.traced_runs[0].id, uuid.UUID)
assert cb.traced_runs[0].inputs == {"prompts": ["hi"]}
Loading…
Cancel
Save