Tracing Group (#5326)

Add context manager to group all runs under a virtual parent

---------

Co-authored-by: vowelparrot <130414180+vowelparrot@users.noreply.github.com>
searx_updates
Ankush Gola 12 months ago committed by GitHub
parent d5b1608216
commit 84a46753ab
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -5,9 +5,20 @@ import functools
import logging
import os
import warnings
from contextlib import contextmanager
from contextlib import asynccontextmanager, contextmanager
from contextvars import ContextVar
from typing import Any, Dict, Generator, List, Optional, Type, TypeVar, Union, cast
from typing import (
Any,
AsyncGenerator,
Dict,
Generator,
List,
Optional,
Type,
TypeVar,
Union,
cast,
)
from uuid import UUID, uuid4
import langchain
@ -116,6 +127,58 @@ def tracing_v2_enabled(
tracing_v2_callback_var.set(None)
@contextmanager
def trace_as_chain_group(
group_name: str,
*,
session_name: Optional[str] = None,
example_id: Optional[Union[str, UUID]] = None,
tenant_id: Optional[str] = None,
session_extra: Optional[Dict[str, Any]] = None,
) -> Generator[CallbackManager, None, None]:
"""Get a callback manager for a chain group in a context manager."""
cb = LangChainTracer(
tenant_id=tenant_id,
session_name=session_name,
example_id=example_id,
session_extra=session_extra,
)
cm = CallbackManager.configure(
inheritable_callbacks=[cb],
)
run_manager = cm.on_chain_start({"name": group_name}, {})
yield run_manager.get_child()
run_manager.on_chain_end({})
@asynccontextmanager
async def atrace_as_chain_group(
group_name: str,
*,
session_name: Optional[str] = None,
example_id: Optional[Union[str, UUID]] = None,
tenant_id: Optional[str] = None,
session_extra: Optional[Dict[str, Any]] = None,
) -> AsyncGenerator[AsyncCallbackManager, None]:
"""Get a callback manager for a chain group in a context manager."""
cb = LangChainTracer(
tenant_id=tenant_id,
session_name=session_name,
example_id=example_id,
session_extra=session_extra,
)
cm = AsyncCallbackManager.configure(
inheritable_callbacks=[cb],
)
run_manager = await cm.on_chain_start({"name": group_name}, {})
try:
yield run_manager.get_child()
finally:
await run_manager.on_chain_end({})
def _handle_event(
handlers: List[BaseCallbackHandler],
event_name: str,

@ -5,7 +5,7 @@ import logging
import os
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union
from uuid import UUID
import requests
@ -68,7 +68,7 @@ class LangChainTracer(BaseTracer):
def __init__(
self,
example_id: Optional[UUID] = None,
example_id: Optional[Union[UUID, str]] = None,
session_name: Optional[str] = None,
**kwargs: Any,
) -> None:
@ -77,7 +77,9 @@ class LangChainTracer(BaseTracer):
self.session: Optional[TracerSession] = None
self._endpoint = get_endpoint()
self._headers = get_headers()
self.example_id = example_id
self.example_id = (
UUID(example_id) if isinstance(example_id, str) else example_id
)
self.session_name = session_name or os.getenv("LANGCHAIN_SESSION", "default")
# set max_workers to 1 to process tasks in order
self.executor = ThreadPoolExecutor(max_workers=1)

@ -7,9 +7,15 @@ from aiohttp import ClientSession
from langchain.agents import AgentType, initialize_agent, load_tools
from langchain.callbacks import tracing_enabled
from langchain.callbacks.manager import tracing_v2_enabled
from langchain.callbacks.manager import (
atrace_as_chain_group,
trace_as_chain_group,
tracing_v2_enabled,
)
from langchain.chains import LLMChain
from langchain.chat_models import ChatOpenAI
from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate
questions = [
(
@ -152,3 +158,59 @@ def test_tracing_v2_context_manager() -> None:
agent.run(questions[0]) # this should be traced
agent.run(questions[0]) # this should not be traced
def test_trace_as_group() -> None:
llm = OpenAI(temperature=0.9)
prompt = PromptTemplate(
input_variables=["product"],
template="What is a good name for a company that makes {product}?",
)
chain = LLMChain(llm=llm, prompt=prompt)
with trace_as_chain_group("my_group") as group_manager:
chain.run(product="cars", callbacks=group_manager)
chain.run(product="computers", callbacks=group_manager)
chain.run(product="toys", callbacks=group_manager)
with trace_as_chain_group("my_group_2") as group_manager:
chain.run(product="toys", callbacks=group_manager)
def test_trace_as_group_with_env_set() -> None:
os.environ["LANGCHAIN_TRACING_V2"] = "true"
llm = OpenAI(temperature=0.9)
prompt = PromptTemplate(
input_variables=["product"],
template="What is a good name for a company that makes {product}?",
)
chain = LLMChain(llm=llm, prompt=prompt)
with trace_as_chain_group("my_group") as group_manager:
chain.run(product="cars", callbacks=group_manager)
chain.run(product="computers", callbacks=group_manager)
chain.run(product="toys", callbacks=group_manager)
with trace_as_chain_group("my_group_2") as group_manager:
chain.run(product="toys", callbacks=group_manager)
@pytest.mark.asyncio
async def test_trace_as_group_async() -> None:
llm = OpenAI(temperature=0.9)
prompt = PromptTemplate(
input_variables=["product"],
template="What is a good name for a company that makes {product}?",
)
chain = LLMChain(llm=llm, prompt=prompt)
async with atrace_as_chain_group("my_group") as group_manager:
await chain.arun(product="cars", callbacks=group_manager)
await chain.arun(product="computers", callbacks=group_manager)
await chain.arun(product="toys", callbacks=group_manager)
async with atrace_as_chain_group("my_group_2") as group_manager:
await asyncio.gather(
*[
chain.arun(product="toys", callbacks=group_manager),
chain.arun(product="computers", callbacks=group_manager),
chain.arun(product="cars", callbacks=group_manager),
]
)

Loading…
Cancel
Save