Tracing Group (#5326)

Add context manager to group all runs under a virtual parent

---------

Co-authored-by: vowelparrot <130414180+vowelparrot@users.noreply.github.com>
This commit is contained in:
Ankush Gola 2023-06-05 19:18:43 -07:00 committed by GitHub
parent d5b1608216
commit 84a46753ab
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 133 additions and 6 deletions

View File

@ -5,9 +5,20 @@ import functools
import logging import logging
import os import os
import warnings import warnings
from contextlib import contextmanager from contextlib import asynccontextmanager, contextmanager
from contextvars import ContextVar from contextvars import ContextVar
from typing import Any, Dict, Generator, List, Optional, Type, TypeVar, Union, cast from typing import (
Any,
AsyncGenerator,
Dict,
Generator,
List,
Optional,
Type,
TypeVar,
Union,
cast,
)
from uuid import UUID, uuid4 from uuid import UUID, uuid4
import langchain import langchain
@ -116,6 +127,58 @@ def tracing_v2_enabled(
tracing_v2_callback_var.set(None) tracing_v2_callback_var.set(None)
@contextmanager
def trace_as_chain_group(
group_name: str,
*,
session_name: Optional[str] = None,
example_id: Optional[Union[str, UUID]] = None,
tenant_id: Optional[str] = None,
session_extra: Optional[Dict[str, Any]] = None,
) -> Generator[CallbackManager, None, None]:
"""Get a callback manager for a chain group in a context manager."""
cb = LangChainTracer(
tenant_id=tenant_id,
session_name=session_name,
example_id=example_id,
session_extra=session_extra,
)
cm = CallbackManager.configure(
inheritable_callbacks=[cb],
)
run_manager = cm.on_chain_start({"name": group_name}, {})
yield run_manager.get_child()
run_manager.on_chain_end({})
@asynccontextmanager
async def atrace_as_chain_group(
group_name: str,
*,
session_name: Optional[str] = None,
example_id: Optional[Union[str, UUID]] = None,
tenant_id: Optional[str] = None,
session_extra: Optional[Dict[str, Any]] = None,
) -> AsyncGenerator[AsyncCallbackManager, None]:
"""Get a callback manager for a chain group in a context manager."""
cb = LangChainTracer(
tenant_id=tenant_id,
session_name=session_name,
example_id=example_id,
session_extra=session_extra,
)
cm = AsyncCallbackManager.configure(
inheritable_callbacks=[cb],
)
run_manager = await cm.on_chain_start({"name": group_name}, {})
try:
yield run_manager.get_child()
finally:
await run_manager.on_chain_end({})
def _handle_event( def _handle_event(
handlers: List[BaseCallbackHandler], handlers: List[BaseCallbackHandler],
event_name: str, event_name: str,

View File

@ -5,7 +5,7 @@ import logging
import os import os
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from datetime import datetime from datetime import datetime
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional, Union
from uuid import UUID from uuid import UUID
import requests import requests
@ -68,7 +68,7 @@ class LangChainTracer(BaseTracer):
def __init__( def __init__(
self, self,
example_id: Optional[UUID] = None, example_id: Optional[Union[UUID, str]] = None,
session_name: Optional[str] = None, session_name: Optional[str] = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
@ -77,7 +77,9 @@ class LangChainTracer(BaseTracer):
self.session: Optional[TracerSession] = None self.session: Optional[TracerSession] = None
self._endpoint = get_endpoint() self._endpoint = get_endpoint()
self._headers = get_headers() self._headers = get_headers()
self.example_id = example_id self.example_id = (
UUID(example_id) if isinstance(example_id, str) else example_id
)
self.session_name = session_name or os.getenv("LANGCHAIN_SESSION", "default") self.session_name = session_name or os.getenv("LANGCHAIN_SESSION", "default")
# set max_workers to 1 to process tasks in order # set max_workers to 1 to process tasks in order
self.executor = ThreadPoolExecutor(max_workers=1) self.executor = ThreadPoolExecutor(max_workers=1)

View File

@ -7,9 +7,15 @@ from aiohttp import ClientSession
from langchain.agents import AgentType, initialize_agent, load_tools from langchain.agents import AgentType, initialize_agent, load_tools
from langchain.callbacks import tracing_enabled from langchain.callbacks import tracing_enabled
from langchain.callbacks.manager import tracing_v2_enabled from langchain.callbacks.manager import (
atrace_as_chain_group,
trace_as_chain_group,
tracing_v2_enabled,
)
from langchain.chains import LLMChain
from langchain.chat_models import ChatOpenAI from langchain.chat_models import ChatOpenAI
from langchain.llms import OpenAI from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate
questions = [ questions = [
( (
@ -152,3 +158,59 @@ def test_tracing_v2_context_manager() -> None:
agent.run(questions[0]) # this should be traced agent.run(questions[0]) # this should be traced
agent.run(questions[0]) # this should not be traced agent.run(questions[0]) # this should not be traced
def test_trace_as_group() -> None:
llm = OpenAI(temperature=0.9)
prompt = PromptTemplate(
input_variables=["product"],
template="What is a good name for a company that makes {product}?",
)
chain = LLMChain(llm=llm, prompt=prompt)
with trace_as_chain_group("my_group") as group_manager:
chain.run(product="cars", callbacks=group_manager)
chain.run(product="computers", callbacks=group_manager)
chain.run(product="toys", callbacks=group_manager)
with trace_as_chain_group("my_group_2") as group_manager:
chain.run(product="toys", callbacks=group_manager)
def test_trace_as_group_with_env_set() -> None:
os.environ["LANGCHAIN_TRACING_V2"] = "true"
llm = OpenAI(temperature=0.9)
prompt = PromptTemplate(
input_variables=["product"],
template="What is a good name for a company that makes {product}?",
)
chain = LLMChain(llm=llm, prompt=prompt)
with trace_as_chain_group("my_group") as group_manager:
chain.run(product="cars", callbacks=group_manager)
chain.run(product="computers", callbacks=group_manager)
chain.run(product="toys", callbacks=group_manager)
with trace_as_chain_group("my_group_2") as group_manager:
chain.run(product="toys", callbacks=group_manager)
@pytest.mark.asyncio
async def test_trace_as_group_async() -> None:
llm = OpenAI(temperature=0.9)
prompt = PromptTemplate(
input_variables=["product"],
template="What is a good name for a company that makes {product}?",
)
chain = LLMChain(llm=llm, prompt=prompt)
async with atrace_as_chain_group("my_group") as group_manager:
await chain.arun(product="cars", callbacks=group_manager)
await chain.arun(product="computers", callbacks=group_manager)
await chain.arun(product="toys", callbacks=group_manager)
async with atrace_as_chain_group("my_group_2") as group_manager:
await asyncio.gather(
*[
chain.arun(product="toys", callbacks=group_manager),
chain.arun(product="computers", callbacks=group_manager),
chain.arun(product="cars", callbacks=group_manager),
]
)