mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
Tracing Group (#5326)
Add context manager to group all runs under a virtual parent --------- Co-authored-by: vowelparrot <130414180+vowelparrot@users.noreply.github.com>
This commit is contained in:
parent
d5b1608216
commit
84a46753ab
@ -5,9 +5,20 @@ import functools
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
from contextlib import contextmanager
|
from contextlib import asynccontextmanager, contextmanager
|
||||||
from contextvars import ContextVar
|
from contextvars import ContextVar
|
||||||
from typing import Any, Dict, Generator, List, Optional, Type, TypeVar, Union, cast
|
from typing import (
|
||||||
|
Any,
|
||||||
|
AsyncGenerator,
|
||||||
|
Dict,
|
||||||
|
Generator,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Type,
|
||||||
|
TypeVar,
|
||||||
|
Union,
|
||||||
|
cast,
|
||||||
|
)
|
||||||
from uuid import UUID, uuid4
|
from uuid import UUID, uuid4
|
||||||
|
|
||||||
import langchain
|
import langchain
|
||||||
@ -116,6 +127,58 @@ def tracing_v2_enabled(
|
|||||||
tracing_v2_callback_var.set(None)
|
tracing_v2_callback_var.set(None)
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def trace_as_chain_group(
|
||||||
|
group_name: str,
|
||||||
|
*,
|
||||||
|
session_name: Optional[str] = None,
|
||||||
|
example_id: Optional[Union[str, UUID]] = None,
|
||||||
|
tenant_id: Optional[str] = None,
|
||||||
|
session_extra: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> Generator[CallbackManager, None, None]:
|
||||||
|
"""Get a callback manager for a chain group in a context manager."""
|
||||||
|
cb = LangChainTracer(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
session_name=session_name,
|
||||||
|
example_id=example_id,
|
||||||
|
session_extra=session_extra,
|
||||||
|
)
|
||||||
|
cm = CallbackManager.configure(
|
||||||
|
inheritable_callbacks=[cb],
|
||||||
|
)
|
||||||
|
|
||||||
|
run_manager = cm.on_chain_start({"name": group_name}, {})
|
||||||
|
yield run_manager.get_child()
|
||||||
|
run_manager.on_chain_end({})
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def atrace_as_chain_group(
|
||||||
|
group_name: str,
|
||||||
|
*,
|
||||||
|
session_name: Optional[str] = None,
|
||||||
|
example_id: Optional[Union[str, UUID]] = None,
|
||||||
|
tenant_id: Optional[str] = None,
|
||||||
|
session_extra: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> AsyncGenerator[AsyncCallbackManager, None]:
|
||||||
|
"""Get a callback manager for a chain group in a context manager."""
|
||||||
|
cb = LangChainTracer(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
session_name=session_name,
|
||||||
|
example_id=example_id,
|
||||||
|
session_extra=session_extra,
|
||||||
|
)
|
||||||
|
cm = AsyncCallbackManager.configure(
|
||||||
|
inheritable_callbacks=[cb],
|
||||||
|
)
|
||||||
|
|
||||||
|
run_manager = await cm.on_chain_start({"name": group_name}, {})
|
||||||
|
try:
|
||||||
|
yield run_manager.get_child()
|
||||||
|
finally:
|
||||||
|
await run_manager.on_chain_end({})
|
||||||
|
|
||||||
|
|
||||||
def _handle_event(
|
def _handle_event(
|
||||||
handlers: List[BaseCallbackHandler],
|
handlers: List[BaseCallbackHandler],
|
||||||
event_name: str,
|
event_name: str,
|
||||||
|
@ -5,7 +5,7 @@ import logging
|
|||||||
import os
|
import os
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional, Union
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
@ -68,7 +68,7 @@ class LangChainTracer(BaseTracer):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
example_id: Optional[UUID] = None,
|
example_id: Optional[Union[UUID, str]] = None,
|
||||||
session_name: Optional[str] = None,
|
session_name: Optional[str] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -77,7 +77,9 @@ class LangChainTracer(BaseTracer):
|
|||||||
self.session: Optional[TracerSession] = None
|
self.session: Optional[TracerSession] = None
|
||||||
self._endpoint = get_endpoint()
|
self._endpoint = get_endpoint()
|
||||||
self._headers = get_headers()
|
self._headers = get_headers()
|
||||||
self.example_id = example_id
|
self.example_id = (
|
||||||
|
UUID(example_id) if isinstance(example_id, str) else example_id
|
||||||
|
)
|
||||||
self.session_name = session_name or os.getenv("LANGCHAIN_SESSION", "default")
|
self.session_name = session_name or os.getenv("LANGCHAIN_SESSION", "default")
|
||||||
# set max_workers to 1 to process tasks in order
|
# set max_workers to 1 to process tasks in order
|
||||||
self.executor = ThreadPoolExecutor(max_workers=1)
|
self.executor = ThreadPoolExecutor(max_workers=1)
|
||||||
|
@ -7,9 +7,15 @@ from aiohttp import ClientSession
|
|||||||
|
|
||||||
from langchain.agents import AgentType, initialize_agent, load_tools
|
from langchain.agents import AgentType, initialize_agent, load_tools
|
||||||
from langchain.callbacks import tracing_enabled
|
from langchain.callbacks import tracing_enabled
|
||||||
from langchain.callbacks.manager import tracing_v2_enabled
|
from langchain.callbacks.manager import (
|
||||||
|
atrace_as_chain_group,
|
||||||
|
trace_as_chain_group,
|
||||||
|
tracing_v2_enabled,
|
||||||
|
)
|
||||||
|
from langchain.chains import LLMChain
|
||||||
from langchain.chat_models import ChatOpenAI
|
from langchain.chat_models import ChatOpenAI
|
||||||
from langchain.llms import OpenAI
|
from langchain.llms import OpenAI
|
||||||
|
from langchain.prompts import PromptTemplate
|
||||||
|
|
||||||
questions = [
|
questions = [
|
||||||
(
|
(
|
||||||
@ -152,3 +158,59 @@ def test_tracing_v2_context_manager() -> None:
|
|||||||
agent.run(questions[0]) # this should be traced
|
agent.run(questions[0]) # this should be traced
|
||||||
|
|
||||||
agent.run(questions[0]) # this should not be traced
|
agent.run(questions[0]) # this should not be traced
|
||||||
|
|
||||||
|
|
||||||
|
def test_trace_as_group() -> None:
|
||||||
|
llm = OpenAI(temperature=0.9)
|
||||||
|
prompt = PromptTemplate(
|
||||||
|
input_variables=["product"],
|
||||||
|
template="What is a good name for a company that makes {product}?",
|
||||||
|
)
|
||||||
|
chain = LLMChain(llm=llm, prompt=prompt)
|
||||||
|
with trace_as_chain_group("my_group") as group_manager:
|
||||||
|
chain.run(product="cars", callbacks=group_manager)
|
||||||
|
chain.run(product="computers", callbacks=group_manager)
|
||||||
|
chain.run(product="toys", callbacks=group_manager)
|
||||||
|
|
||||||
|
with trace_as_chain_group("my_group_2") as group_manager:
|
||||||
|
chain.run(product="toys", callbacks=group_manager)
|
||||||
|
|
||||||
|
|
||||||
|
def test_trace_as_group_with_env_set() -> None:
|
||||||
|
os.environ["LANGCHAIN_TRACING_V2"] = "true"
|
||||||
|
llm = OpenAI(temperature=0.9)
|
||||||
|
prompt = PromptTemplate(
|
||||||
|
input_variables=["product"],
|
||||||
|
template="What is a good name for a company that makes {product}?",
|
||||||
|
)
|
||||||
|
chain = LLMChain(llm=llm, prompt=prompt)
|
||||||
|
with trace_as_chain_group("my_group") as group_manager:
|
||||||
|
chain.run(product="cars", callbacks=group_manager)
|
||||||
|
chain.run(product="computers", callbacks=group_manager)
|
||||||
|
chain.run(product="toys", callbacks=group_manager)
|
||||||
|
|
||||||
|
with trace_as_chain_group("my_group_2") as group_manager:
|
||||||
|
chain.run(product="toys", callbacks=group_manager)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_trace_as_group_async() -> None:
|
||||||
|
llm = OpenAI(temperature=0.9)
|
||||||
|
prompt = PromptTemplate(
|
||||||
|
input_variables=["product"],
|
||||||
|
template="What is a good name for a company that makes {product}?",
|
||||||
|
)
|
||||||
|
chain = LLMChain(llm=llm, prompt=prompt)
|
||||||
|
async with atrace_as_chain_group("my_group") as group_manager:
|
||||||
|
await chain.arun(product="cars", callbacks=group_manager)
|
||||||
|
await chain.arun(product="computers", callbacks=group_manager)
|
||||||
|
await chain.arun(product="toys", callbacks=group_manager)
|
||||||
|
|
||||||
|
async with atrace_as_chain_group("my_group_2") as group_manager:
|
||||||
|
await asyncio.gather(
|
||||||
|
*[
|
||||||
|
chain.arun(product="toys", callbacks=group_manager),
|
||||||
|
chain.arun(product="computers", callbacks=group_manager),
|
||||||
|
chain.arun(product="cars", callbacks=group_manager),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user