diff --git a/langchain/callbacks/manager.py b/langchain/callbacks/manager.py index ebcbca8c..95399cc7 100644 --- a/langchain/callbacks/manager.py +++ b/langchain/callbacks/manager.py @@ -5,9 +5,20 @@ import functools import logging import os import warnings -from contextlib import contextmanager +from contextlib import asynccontextmanager, contextmanager from contextvars import ContextVar -from typing import Any, Dict, Generator, List, Optional, Type, TypeVar, Union, cast +from typing import ( + Any, + AsyncGenerator, + Dict, + Generator, + List, + Optional, + Type, + TypeVar, + Union, + cast, +) from uuid import UUID, uuid4 import langchain @@ -116,6 +127,58 @@ def tracing_v2_enabled( tracing_v2_callback_var.set(None) +@contextmanager +def trace_as_chain_group( + group_name: str, + *, + session_name: Optional[str] = None, + example_id: Optional[Union[str, UUID]] = None, + tenant_id: Optional[str] = None, + session_extra: Optional[Dict[str, Any]] = None, +) -> Generator[CallbackManager, None, None]: + """Get a callback manager for a chain group in a context manager.""" + cb = LangChainTracer( + tenant_id=tenant_id, + session_name=session_name, + example_id=example_id, + session_extra=session_extra, + ) + cm = CallbackManager.configure( + inheritable_callbacks=[cb], + ) + + run_manager = cm.on_chain_start({"name": group_name}, {}) + yield run_manager.get_child() + run_manager.on_chain_end({}) + + +@asynccontextmanager +async def atrace_as_chain_group( + group_name: str, + *, + session_name: Optional[str] = None, + example_id: Optional[Union[str, UUID]] = None, + tenant_id: Optional[str] = None, + session_extra: Optional[Dict[str, Any]] = None, +) -> AsyncGenerator[AsyncCallbackManager, None]: + """Get a callback manager for a chain group in a context manager.""" + cb = LangChainTracer( + tenant_id=tenant_id, + session_name=session_name, + example_id=example_id, + session_extra=session_extra, + ) + cm = AsyncCallbackManager.configure( + inheritable_callbacks=[cb], + ) + + run_manager = await cm.on_chain_start({"name": group_name}, {}) + try: + yield run_manager.get_child() + finally: + await run_manager.on_chain_end({}) + + def _handle_event( handlers: List[BaseCallbackHandler], event_name: str, diff --git a/langchain/callbacks/tracers/langchain.py b/langchain/callbacks/tracers/langchain.py index bbbccfe8..e38490d3 100644 --- a/langchain/callbacks/tracers/langchain.py +++ b/langchain/callbacks/tracers/langchain.py @@ -5,7 +5,7 @@ import logging import os from concurrent.futures import ThreadPoolExecutor from datetime import datetime -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union from uuid import UUID import requests @@ -68,7 +68,7 @@ class LangChainTracer(BaseTracer): def __init__( self, - example_id: Optional[UUID] = None, + example_id: Optional[Union[UUID, str]] = None, session_name: Optional[str] = None, **kwargs: Any, ) -> None: @@ -77,7 +77,9 @@ class LangChainTracer(BaseTracer): self.session: Optional[TracerSession] = None self._endpoint = get_endpoint() self._headers = get_headers() - self.example_id = example_id + self.example_id = ( + UUID(example_id) if isinstance(example_id, str) else example_id + ) self.session_name = session_name or os.getenv("LANGCHAIN_SESSION", "default") # set max_workers to 1 to process tasks in order self.executor = ThreadPoolExecutor(max_workers=1) diff --git a/tests/integration_tests/callbacks/test_langchain_tracer.py b/tests/integration_tests/callbacks/test_langchain_tracer.py index ca33cf91..80d18713 100644 --- a/tests/integration_tests/callbacks/test_langchain_tracer.py +++ b/tests/integration_tests/callbacks/test_langchain_tracer.py @@ -7,9 +7,15 @@ from aiohttp import ClientSession from langchain.agents import AgentType, initialize_agent, load_tools from langchain.callbacks import tracing_enabled -from langchain.callbacks.manager import tracing_v2_enabled +from langchain.callbacks.manager import ( + atrace_as_chain_group, + trace_as_chain_group, + tracing_v2_enabled, +) +from langchain.chains import LLMChain from langchain.chat_models import ChatOpenAI from langchain.llms import OpenAI +from langchain.prompts import PromptTemplate questions = [ ( @@ -152,3 +158,59 @@ def test_tracing_v2_context_manager() -> None: agent.run(questions[0]) # this should be traced agent.run(questions[0]) # this should not be traced + + +def test_trace_as_group() -> None: + llm = OpenAI(temperature=0.9) + prompt = PromptTemplate( + input_variables=["product"], + template="What is a good name for a company that makes {product}?", + ) + chain = LLMChain(llm=llm, prompt=prompt) + with trace_as_chain_group("my_group") as group_manager: + chain.run(product="cars", callbacks=group_manager) + chain.run(product="computers", callbacks=group_manager) + chain.run(product="toys", callbacks=group_manager) + + with trace_as_chain_group("my_group_2") as group_manager: + chain.run(product="toys", callbacks=group_manager) + + +def test_trace_as_group_with_env_set() -> None: + os.environ["LANGCHAIN_TRACING_V2"] = "true" + llm = OpenAI(temperature=0.9) + prompt = PromptTemplate( + input_variables=["product"], + template="What is a good name for a company that makes {product}?", + ) + chain = LLMChain(llm=llm, prompt=prompt) + with trace_as_chain_group("my_group") as group_manager: + chain.run(product="cars", callbacks=group_manager) + chain.run(product="computers", callbacks=group_manager) + chain.run(product="toys", callbacks=group_manager) + + with trace_as_chain_group("my_group_2") as group_manager: + chain.run(product="toys", callbacks=group_manager) + + +@pytest.mark.asyncio +async def test_trace_as_group_async() -> None: + llm = OpenAI(temperature=0.9) + prompt = PromptTemplate( + input_variables=["product"], + template="What is a good name for a company that makes {product}?", + ) + chain = LLMChain(llm=llm, prompt=prompt) + async with atrace_as_chain_group("my_group") as group_manager: + await chain.arun(product="cars", callbacks=group_manager) + await chain.arun(product="computers", callbacks=group_manager) + await chain.arun(product="toys", callbacks=group_manager) + + async with atrace_as_chain_group("my_group_2") as group_manager: + await asyncio.gather( + *[ + chain.arun(product="toys", callbacks=group_manager), + chain.arun(product="computers", callbacks=group_manager), + chain.arun(product="cars", callbacks=group_manager), + ] + )