core[patch], langchain[patch], community[patch]: Revert #15326 (#15546)

Branch: pull/15547/head
Authored by Bagatur 5 months ago, committed via GitHub
Parent: 7a93356cbc
Commit: b2f15738dd

@@ -1,6 +1,7 @@
"""Tracers that record execution of LangChain runs."""
from langchain_core.tracers.langchain import LangChainTracer
from langchain_core.tracers.langchain_v1 import LangChainTracerV1
from langchain_core.tracers.stdout import (
ConsoleCallbackHandler,
FunctionCallbackHandler,
@@ -12,5 +13,6 @@ __all__ = [
"ConsoleCallbackHandler",
"FunctionCallbackHandler",
"LangChainTracer",
"LangChainTracerV1",
"WandbTracer",
]

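For orientation before the larger hunks below, a minimal sketch of what this revert makes importable again (import paths are taken from the hunk above; the instantiation details are otherwise assumed):

from langchain_core.tracers.langchain import LangChainTracer
from langchain_core.tracers.langchain_v1 import LangChainTracerV1

# The V1 tracer is the deprecated client restored by this commit; it takes no
# required arguments and resolves its endpoint/headers from the environment.
v1_tracer = LangChainTracerV1()
v2_tracer = LangChainTracer()  # unchanged LangSmith-backed tracer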
@@ -5,7 +5,7 @@ import os
from aiohttp import ClientSession
from langchain_core.callbacks.manager import atrace_as_chain_group, trace_as_chain_group
from langchain_core.prompts import PromptTemplate
-from langchain_core.tracers.context import tracing_v2_enabled
+from langchain_core.tracers.context import tracing_enabled, tracing_v2_enabled
from langchain_community.chat_models import ChatOpenAI
from langchain_community.llms import OpenAI
@@ -76,6 +76,63 @@ async def test_tracing_concurrent() -> None:
await aiosession.close()
async def test_tracing_concurrent_bw_compat_environ() -> None:
from langchain.agents import AgentType, initialize_agent, load_tools
os.environ["LANGCHAIN_HANDLER"] = "langchain"
if "LANGCHAIN_TRACING" in os.environ:
del os.environ["LANGCHAIN_TRACING"]
aiosession = ClientSession()
llm = OpenAI(temperature=0)
async_tools = load_tools(["llm-math", "serpapi"], llm=llm, aiosession=aiosession)
agent = initialize_agent(
async_tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
)
tasks = [agent.arun(q) for q in questions[:3]]
await asyncio.gather(*tasks)
await aiosession.close()
if "LANGCHAIN_HANDLER" in os.environ:
del os.environ["LANGCHAIN_HANDLER"]
def test_tracing_context_manager() -> None:
from langchain.agents import AgentType, initialize_agent, load_tools
llm = OpenAI(temperature=0)
tools = load_tools(["llm-math", "serpapi"], llm=llm)
agent = initialize_agent(
tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
)
if "LANGCHAIN_TRACING" in os.environ:
del os.environ["LANGCHAIN_TRACING"]
with tracing_enabled() as session:
assert session
agent.run(questions[0]) # this should be traced
agent.run(questions[0]) # this should not be traced
async def test_tracing_context_manager_async() -> None:
from langchain.agents import AgentType, initialize_agent, load_tools
llm = OpenAI(temperature=0)
async_tools = load_tools(["llm-math", "serpapi"], llm=llm)
agent = initialize_agent(
async_tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
)
if "LANGCHAIN_TRACING" in os.environ:
del os.environ["LANGCHAIN_TRACING"]
# start a background task
task = asyncio.create_task(agent.arun(questions[0])) # this should not be traced
with tracing_enabled() as session:
assert session
tasks = [agent.arun(q) for q in questions[1:4]] # these should be traced
await asyncio.gather(*tasks)
await task
async def test_tracing_v2_environment_variable() -> None:
from langchain.agents import AgentType, initialize_agent, load_tools

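The backward-compat tests above mutate os.environ by hand and clean up with del. A hedged alternative sketch using pytest's monkeypatch fixture (the same fixture the new V1 unit tests below use), with an illustrative test name and an elided body:

import pytest

def test_tracing_bw_compat_env(monkeypatch: pytest.MonkeyPatch) -> None:
    # Same environment setup as test_tracing_concurrent_bw_compat_environ,
    # but the fixture restores os.environ automatically when the test exits.
    monkeypatch.setenv("LANGCHAIN_HANDLER", "langchain")
    monkeypatch.delenv("LANGCHAIN_TRACING", raising=False)
    # ... build the agent and run the questions as in the test above ...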
@@ -1847,6 +1847,7 @@ def _configure(
_configure_hooks,
_get_tracer_project,
_tracing_v2_is_enabled,
tracing_callback_var,
tracing_v2_callback_var,
)
@@ -1885,12 +1886,20 @@ def _configure(
callback_manager.add_metadata(inheritable_metadata or {})
callback_manager.add_metadata(local_metadata or {}, False)
tracer = tracing_callback_var.get()
tracing_enabled_ = (
env_var_is_set("LANGCHAIN_TRACING")
or tracer is not None
or env_var_is_set("LANGCHAIN_HANDLER")
)
tracer_v2 = tracing_v2_callback_var.get()
tracing_v2_enabled_ = _tracing_v2_is_enabled()
tracer_project = _get_tracer_project()
debug = _get_debug()
-if verbose or debug or tracing_v2_enabled_:
+if verbose or debug or tracing_enabled_ or tracing_v2_enabled_:
from langchain_core.tracers.langchain import LangChainTracer
from langchain_core.tracers.langchain_v1 import LangChainTracerV1
from langchain_core.tracers.stdout import ConsoleCallbackHandler
if verbose and not any(
@@ -1898,7 +1907,6 @@ def _configure(
for handler in callback_manager.handlers
):
if debug:
-# We will use ConsoleCallbackHandler instead of StdOutCallbackHandler
pass
else:
callback_manager.add_handler(StdOutCallbackHandler(), False)
@@ -1907,6 +1915,16 @@ def _configure(
for handler in callback_manager.handlers
):
callback_manager.add_handler(ConsoleCallbackHandler(), True)
if tracing_enabled_ and not any(
isinstance(handler, LangChainTracerV1)
for handler in callback_manager.handlers
):
if tracer:
callback_manager.add_handler(tracer, True)
else:
handler = LangChainTracerV1()
handler.load_session(tracer_project)
callback_manager.add_handler(handler, True)
if tracing_v2_enabled_ and not any(
isinstance(handler, LangChainTracer)
for handler in callback_manager.handlers

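Condensed restatement of the V1 gating this hunk restores in _configure (a sketch, not the actual function; env_var_is_set below approximates the helper imported in the diff):

import os
from typing import Optional

def env_var_is_set(name: str) -> bool:
    # Approximation of langchain_core.utils.env.env_var_is_set: a variable
    # counts as set when it holds a non-empty, non-falsy string.
    return os.environ.get(name, "") not in ("", "0", "false", "False")

def v1_tracing_requested(context_tracer: Optional[object]) -> bool:
    # Mirrors tracing_enabled_ above: the legacy LANGCHAIN_TRACING variable,
    # a tracer stashed in tracing_callback_var (by tracing_enabled()), or the
    # even older LANGCHAIN_HANDLER switch all turn V1 tracing on.
    return (
        env_var_is_set("LANGCHAIN_TRACING")
        or context_tracer is not None
        or env_var_is_set("LANGCHAIN_HANDLER")
    )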
@@ -19,7 +19,9 @@ from langsmith import utils as ls_utils
from langsmith.run_helpers import get_run_tree_context
from langchain_core.tracers.langchain import LangChainTracer
from langchain_core.tracers.langchain_v1 import LangChainTracerV1
from langchain_core.tracers.run_collector import RunCollectorCallbackHandler
from langchain_core.tracers.schemas import TracerSessionV1
from langchain_core.utils.env import env_var_is_set
if TYPE_CHECKING:
@@ -28,6 +30,10 @@ if TYPE_CHECKING:
from langchain_core.callbacks.base import BaseCallbackHandler, Callbacks
from langchain_core.callbacks.manager import AsyncCallbackManager, CallbackManager
tracing_callback_var: ContextVar[Optional[LangChainTracerV1]] = ContextVar( # noqa: E501
"tracing_callback", default=None
)
tracing_v2_callback_var: ContextVar[Optional[LangChainTracer]] = ContextVar( # noqa: E501
"tracing_callback_v2", default=None
)
@@ -36,6 +42,32 @@ run_collector_var: ContextVar[Optional[RunCollectorCallbackHandler]] = ContextVar(
)
@contextmanager
def tracing_enabled(
session_name: str = "default",
) -> Generator[TracerSessionV1, None, None]:
"""Get the Deprecated LangChainTracer in a context manager.
Args:
session_name (str, optional): The name of the session.
Defaults to "default".
Returns:
TracerSessionV1: The LangChainTracer session.
Example:
>>> with tracing_enabled() as session:
... # Use the LangChainTracer session
"""
cb = LangChainTracerV1()
session = cast(TracerSessionV1, cb.load_session(session_name))
try:
tracing_callback_var.set(cb)
yield session
finally:
tracing_callback_var.set(None)
@contextmanager
def tracing_v2_enabled(
project_name: Optional[str] = None,
@@ -136,7 +168,6 @@ def _tracing_v2_is_enabled() -> bool:
env_var_is_set("LANGCHAIN_TRACING_V2")
or tracing_v2_callback_var.get() is not None
or get_run_tree_context() is not None
-or env_var_is_set("LANGCHAIN_TRACING")
)

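Usage of the restored context manager, expanded from its own docstring example above:

from langchain_core.tracers.context import tracing_enabled

with tracing_enabled(session_name="default") as session:
    # Inside the block, tracing_callback_var holds a LangChainTracerV1, so
    # _configure wires it into any run started here.
    assert session is not None
# On exit the context var is reset to None and tracing stops.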
@@ -0,0 +1,185 @@
from __future__ import annotations
import logging
import os
from typing import Any, Dict, Optional, Union
import requests
from langchain_core.messages import get_buffer_string
from langchain_core.tracers.base import BaseTracer
from langchain_core.tracers.schemas import (
ChainRun,
LLMRun,
Run,
ToolRun,
TracerSession,
TracerSessionV1,
TracerSessionV1Base,
)
from langchain_core.utils import raise_for_status_with_text
logger = logging.getLogger(__name__)
def get_headers() -> Dict[str, Any]:
"""Get the headers for the LangChain API."""
headers: Dict[str, Any] = {"Content-Type": "application/json"}
if os.getenv("LANGCHAIN_API_KEY"):
headers["x-api-key"] = os.getenv("LANGCHAIN_API_KEY")
return headers
def _get_endpoint() -> str:
return os.getenv("LANGCHAIN_ENDPOINT", "http://localhost:8000")
class LangChainTracerV1(BaseTracer):
"""An implementation of the SharedTracer that POSTS to the langchain endpoint."""
def __init__(self, **kwargs: Any) -> None:
"""Initialize the LangChain tracer."""
super().__init__(**kwargs)
self.session: Optional[TracerSessionV1] = None
self._endpoint = _get_endpoint()
self._headers = get_headers()
def _convert_to_v1_run(self, run: Run) -> Union[LLMRun, ChainRun, ToolRun]:
session = self.session or self.load_default_session()
if not isinstance(session, TracerSessionV1):
raise ValueError(
"LangChainTracerV1 is not compatible with"
f" session of type {type(session)}"
)
if run.run_type == "llm":
if "prompts" in run.inputs:
prompts = run.inputs["prompts"]
elif "messages" in run.inputs:
prompts = [get_buffer_string(batch) for batch in run.inputs["messages"]]
else:
raise ValueError("No prompts found in LLM run inputs")
return LLMRun(
uuid=str(run.id) if run.id else None,
parent_uuid=str(run.parent_run_id) if run.parent_run_id else None,
start_time=run.start_time,
end_time=run.end_time,
extra=run.extra,
execution_order=run.execution_order,
child_execution_order=run.child_execution_order,
serialized=run.serialized,
session_id=session.id,
error=run.error,
prompts=prompts,
response=run.outputs if run.outputs else None,
)
if run.run_type == "chain":
child_runs = [self._convert_to_v1_run(run) for run in run.child_runs]
return ChainRun(
uuid=str(run.id) if run.id else None,
parent_uuid=str(run.parent_run_id) if run.parent_run_id else None,
start_time=run.start_time,
end_time=run.end_time,
execution_order=run.execution_order,
child_execution_order=run.child_execution_order,
serialized=run.serialized,
session_id=session.id,
inputs=run.inputs,
outputs=run.outputs,
error=run.error,
extra=run.extra,
child_llm_runs=[run for run in child_runs if isinstance(run, LLMRun)],
child_chain_runs=[
run for run in child_runs if isinstance(run, ChainRun)
],
child_tool_runs=[run for run in child_runs if isinstance(run, ToolRun)],
)
if run.run_type == "tool":
child_runs = [self._convert_to_v1_run(run) for run in run.child_runs]
return ToolRun(
uuid=str(run.id) if run.id else None,
parent_uuid=str(run.parent_run_id) if run.parent_run_id else None,
start_time=run.start_time,
end_time=run.end_time,
execution_order=run.execution_order,
child_execution_order=run.child_execution_order,
serialized=run.serialized,
session_id=session.id,
action=str(run.serialized),
tool_input=run.inputs.get("input", ""),
output=None if run.outputs is None else run.outputs.get("output"),
error=run.error,
extra=run.extra,
child_chain_runs=[
run for run in child_runs if isinstance(run, ChainRun)
],
child_tool_runs=[run for run in child_runs if isinstance(run, ToolRun)],
child_llm_runs=[run for run in child_runs if isinstance(run, LLMRun)],
)
raise ValueError(f"Unknown run type: {run.run_type}")
def _persist_run(self, run: Union[Run, LLMRun, ChainRun, ToolRun]) -> None:
"""Persist a run."""
if isinstance(run, Run):
v1_run = self._convert_to_v1_run(run)
else:
v1_run = run
if isinstance(v1_run, LLMRun):
endpoint = f"{self._endpoint}/llm-runs"
elif isinstance(v1_run, ChainRun):
endpoint = f"{self._endpoint}/chain-runs"
else:
endpoint = f"{self._endpoint}/tool-runs"
try:
response = requests.post(
endpoint,
data=v1_run.json(),
headers=self._headers,
)
raise_for_status_with_text(response)
except Exception as e:
logger.warning(f"Failed to persist run: {e}")
def _persist_session(
self, session_create: TracerSessionV1Base
) -> Union[TracerSessionV1, TracerSession]:
"""Persist a session."""
try:
r = requests.post(
f"{self._endpoint}/sessions",
data=session_create.json(),
headers=self._headers,
)
session = TracerSessionV1(id=r.json()["id"], **session_create.dict())
except Exception as e:
logger.warning(f"Failed to create session, using default session: {e}")
session = TracerSessionV1(id=1, **session_create.dict())
return session
def _load_session(self, session_name: Optional[str] = None) -> TracerSessionV1:
"""Load a session from the tracer."""
try:
url = f"{self._endpoint}/sessions"
if session_name:
url += f"?name={session_name}"
r = requests.get(url, headers=self._headers)
tracer_session = TracerSessionV1(**r.json()[0])
except Exception as e:
session_type = "default" if not session_name else session_name
logger.warning(
f"Failed to load {session_type} session, using empty session: {e}"
)
tracer_session = TracerSessionV1(id=1)
self.session = tracer_session
return tracer_session
def load_session(self, session_name: str) -> Union[TracerSessionV1, TracerSession]:
"""Load a session with the given name from the tracer."""
return self._load_session(session_name)
def load_default_session(self) -> Union[TracerSessionV1, TracerSession]:
"""Load the default tracing session and set it as the Tracer's session."""
return self._load_session("default")

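Driving the restored tracer directly (a sketch: the endpoint value is the documented default from _get_endpoint, and load_session degrades gracefully when the endpoint is unreachable):

import os
from langchain_core.tracers.langchain_v1 import LangChainTracerV1

os.environ.setdefault("LANGCHAIN_ENDPOINT", "http://localhost:8000")

tracer = LangChainTracerV1()
# GETs {endpoint}/sessions?name=default; per _load_session above, a failure
# logs a warning and yields TracerSessionV1(id=1) instead of raising.
session = tracer.load_session("default")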
@@ -1,12 +1,102 @@
"""Schemas for tracers."""
from __future__ import annotations
from typing import Any, Dict, List, Optional
import datetime
import warnings
from typing import Any, Dict, List, Optional, Type
from uuid import UUID
from langsmith.schemas import RunBase as BaseRunV2
from langsmith.schemas import RunTypeEnum as RunTypeEnumDep
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.outputs import LLMResult
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
def RunTypeEnum() -> Type[RunTypeEnumDep]:
"""RunTypeEnum."""
warnings.warn(
"RunTypeEnum is deprecated. Please directly use a string instead"
" (e.g. 'llm', 'chain', 'tool').",
DeprecationWarning,
)
return RunTypeEnumDep
class TracerSessionV1Base(BaseModel):
"""Base class for TracerSessionV1."""
start_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
name: Optional[str] = None
extra: Optional[Dict[str, Any]] = None
class TracerSessionV1Create(TracerSessionV1Base):
"""Create class for TracerSessionV1."""
class TracerSessionV1(TracerSessionV1Base):
"""TracerSessionV1 schema."""
id: int
class TracerSessionBase(TracerSessionV1Base):
"""Base class for TracerSession."""
tenant_id: UUID
class TracerSession(TracerSessionBase):
"""TracerSessionV1 schema for the V2 API."""
id: UUID
class BaseRun(BaseModel):
"""Base class for Run."""
uuid: str
parent_uuid: Optional[str] = None
start_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
end_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
extra: Optional[Dict[str, Any]] = None
execution_order: int
child_execution_order: int
serialized: Dict[str, Any]
session_id: int
error: Optional[str] = None
class LLMRun(BaseRun):
"""Class for LLMRun."""
prompts: List[str]
response: Optional[LLMResult] = None
class ChainRun(BaseRun):
"""Class for ChainRun."""
inputs: Dict[str, Any]
outputs: Optional[Dict[str, Any]] = None
child_llm_runs: List[LLMRun] = Field(default_factory=list)
child_chain_runs: List[ChainRun] = Field(default_factory=list)
child_tool_runs: List[ToolRun] = Field(default_factory=list)
class ToolRun(BaseRun):
"""Class for ToolRun."""
tool_input: str
output: Optional[str] = None
action: str
child_llm_runs: List[LLMRun] = Field(default_factory=list)
child_chain_runs: List[ChainRun] = Field(default_factory=list)
child_tool_runs: List[ToolRun] = Field(default_factory=list)
# Begin V2 API Schemas
class Run(BaseRunV2):
@@ -33,8 +123,20 @@ class Run(BaseRunV2):
return values
ChainRun.update_forward_refs()
ToolRun.update_forward_refs()
Run.update_forward_refs()
__all__ = [
"BaseRun",
"ChainRun",
"LLMRun",
"Run",
"RunTypeEnum",
"ToolRun",
"TracerSession",
"TracerSessionBase",
"TracerSessionV1",
"TracerSessionV1Base",
"TracerSessionV1Create",
]

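A construction example for the restored V1 schemas (field values are illustrative; the required fields follow BaseRun/LLMRun/ChainRun above, with start_time and end_time defaulting to utcnow):

from langchain_core.tracers.schemas import ChainRun, LLMRun

llm_run = LLMRun(
    uuid="00000000-0000-0000-0000-000000000000",  # plain string, not UUID
    execution_order=1,
    child_execution_order=1,
    serialized={},
    session_id=1,  # V1 sessions use integer ids
    prompts=["Hello"],
)
chain_run = ChainRun(
    uuid="00000000-0000-0000-0000-000000000001",
    execution_order=1,
    child_execution_order=1,
    serialized={},
    session_id=1,
    inputs={},
    child_llm_runs=[llm_run],  # the child_* lists default to empty
)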
@@ -2,7 +2,7 @@
from __future__ import annotations
from datetime import datetime, timezone
-from typing import List
+from typing import Any, List
from uuid import uuid4
import pytest
@@ -31,17 +31,17 @@ class FakeTracer(BaseTracer):
self.runs.append(run)
-def _compare_run_with_error(run: Run, expected_run: Run) -> None:
+def _compare_run_with_error(run: Any, expected_run: Any) -> None:
if run.child_runs:
assert len(expected_run.child_runs) == len(run.child_runs)
for received, expected in zip(run.child_runs, expected_run.child_runs):
_compare_run_with_error(received, expected)
-received_dict = run.dict(exclude={"child_runs"})
-received_err = received_dict.pop("error")
-expected_dict = expected_run.dict(exclude={"child_runs"})
-expected_err = expected_dict.pop("error")
+received = run.dict(exclude={"child_runs"})
+received_err = received.pop("error")
+expected = expected_run.dict(exclude={"child_runs"})
+expected_err = expected.pop("error")
-assert received_dict == expected_dict
+assert received == expected
if expected_err is not None:
assert received_err is not None
assert expected_err in received_err
@@ -406,7 +406,6 @@ def test_tracer_llm_run_on_error_callback() -> None:
tracer = FakeTracerWithLlmErrorCallback()
tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
tracer.on_llm_error(exception, run_id=uuid)
-assert tracer.error_run is not None
_compare_run_with_error(tracer.error_run, compare_run)

@@ -0,0 +1,562 @@
"""Test Tracer classes."""
from __future__ import annotations
from datetime import datetime, timezone
from typing import Any, List, Optional, Union
from uuid import uuid4
import pytest
from freezegun import freeze_time
from langchain_core.callbacks import CallbackManager
from langchain_core.messages import HumanMessage
from langchain_core.outputs import LLMResult
from langchain_core.tracers.base import BaseTracer, TracerException
from langchain_core.tracers.langchain_v1 import (
ChainRun,
LangChainTracerV1,
LLMRun,
ToolRun,
TracerSessionV1,
)
from langchain_core.tracers.schemas import Run, TracerSessionV1Base
TEST_SESSION_ID = 2023
SERIALIZED = {"id": ["llm"]}
SERIALIZED_CHAT = {"id": ["chat_model"]}
def load_session(session_name: str) -> TracerSessionV1:
"""Load a tracing session."""
return TracerSessionV1(
id=TEST_SESSION_ID, name=session_name, start_time=datetime.now(timezone.utc)
)
def new_session(name: Optional[str] = None) -> TracerSessionV1:
"""Create a new tracing session."""
return TracerSessionV1(
id=TEST_SESSION_ID,
name=name or "default",
start_time=datetime.now(timezone.utc),
)
def _persist_session(session: TracerSessionV1Base) -> TracerSessionV1:
"""Persist a tracing session."""
return TracerSessionV1(**{**session.dict(), "id": TEST_SESSION_ID})
def load_default_session() -> TracerSessionV1:
"""Load a tracing session."""
return TracerSessionV1(
id=TEST_SESSION_ID, name="default", start_time=datetime.now(timezone.utc)
)
@pytest.fixture
def lang_chain_tracer_v1(monkeypatch: pytest.MonkeyPatch) -> LangChainTracerV1:
monkeypatch.setenv("LANGCHAIN_TENANT_ID", "test-tenant-id")
monkeypatch.setenv("LANGCHAIN_ENDPOINT", "http://test-endpoint.com")
monkeypatch.setenv("LANGCHAIN_API_KEY", "foo")
tracer = LangChainTracerV1()
return tracer
class FakeTracer(BaseTracer):
"""Fake tracer that records LangChain execution."""
def __init__(self) -> None:
"""Initialize the tracer."""
super().__init__()
self.runs: List[Union[LLMRun, ChainRun, ToolRun]] = []
def _persist_run(self, run: Union[Run, LLMRun, ChainRun, ToolRun]) -> None:
"""Persist a run."""
if isinstance(run, Run):
with pytest.MonkeyPatch().context() as m:
m.setenv("LANGCHAIN_TENANT_ID", "test-tenant-id")
m.setenv("LANGCHAIN_ENDPOINT", "http://test-endpoint.com")
m.setenv("LANGCHAIN_API_KEY", "foo")
tracer = LangChainTracerV1()
tracer.load_default_session = load_default_session # type: ignore
run = tracer._convert_to_v1_run(run)
self.runs.append(run)
def _persist_session(self, session: TracerSessionV1Base) -> TracerSessionV1:
"""Persist a tracing session."""
return _persist_session(session)
def new_session(self, name: Optional[str] = None) -> TracerSessionV1:
"""Create a new tracing session."""
return new_session(name)
def load_session(self, session_name: str) -> TracerSessionV1:
"""Load a tracing session."""
return load_session(session_name)
def load_default_session(self) -> TracerSessionV1:
"""Load a tracing session."""
return load_default_session()
def _compare_run_with_error(run: Any, expected_run: Any) -> None:
received = run.dict()
received_err = received.pop("error")
expected = expected_run.dict()
expected_err = expected.pop("error")
assert received == expected
assert expected_err in received_err
@freeze_time("2023-01-01")
def test_tracer_llm_run() -> None:
"""Test tracer on an LLM run."""
uuid = uuid4()
compare_run = LLMRun(
uuid=str(uuid),
parent_uuid=None,
start_time=datetime.now(timezone.utc),
end_time=datetime.now(timezone.utc),
extra={},
execution_order=1,
child_execution_order=1,
serialized=SERIALIZED,
prompts=[],
response=LLMResult(generations=[[]]),
session_id=TEST_SESSION_ID,
error=None,
)
tracer = FakeTracer()
tracer.new_session()
tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid)
assert tracer.runs == [compare_run]
@freeze_time("2023-01-01")
def test_tracer_chat_model_run() -> None:
"""Test tracer on a Chat Model run."""
tracer = FakeTracer()
tracer.new_session()
manager = CallbackManager(handlers=[tracer])
run_managers = manager.on_chat_model_start(
serialized=SERIALIZED_CHAT, messages=[[HumanMessage(content="")]]
)
compare_run = LLMRun(
uuid=str(run_managers[0].run_id),
parent_uuid=None,
start_time=datetime.now(timezone.utc),
end_time=datetime.now(timezone.utc),
extra={},
execution_order=1,
child_execution_order=1,
serialized=SERIALIZED_CHAT,
prompts=["Human: "],
response=LLMResult(generations=[[]]),
session_id=TEST_SESSION_ID,
error=None,
)
for run_manager in run_managers:
run_manager.on_llm_end(response=LLMResult(generations=[[]]))
assert tracer.runs == [compare_run]
@freeze_time("2023-01-01")
def test_tracer_llm_run_errors_no_start() -> None:
"""Test tracer on an LLM run without a start."""
tracer = FakeTracer()
tracer.new_session()
with pytest.raises(TracerException):
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid4())
@freeze_time("2023-01-01")
def test_tracer_multiple_llm_runs() -> None:
"""Test the tracer with multiple runs."""
uuid = uuid4()
compare_run = LLMRun(
uuid=str(uuid),
parent_uuid=None,
start_time=datetime.now(timezone.utc),
end_time=datetime.now(timezone.utc),
extra={},
execution_order=1,
child_execution_order=1,
serialized=SERIALIZED,
prompts=[],
response=LLMResult(generations=[[]]),
session_id=TEST_SESSION_ID,
error=None,
)
tracer = FakeTracer()
tracer.new_session()
num_runs = 10
for _ in range(num_runs):
tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid)
assert tracer.runs == [compare_run] * num_runs
@freeze_time("2023-01-01")
def test_tracer_chain_run() -> None:
"""Test tracer on a Chain run."""
uuid = uuid4()
compare_run = ChainRun(
uuid=str(uuid),
parent_uuid=None,
start_time=datetime.now(timezone.utc),
end_time=datetime.now(timezone.utc),
extra={},
execution_order=1,
child_execution_order=1,
serialized={"name": "chain"},
inputs={},
outputs={},
session_id=TEST_SESSION_ID,
error=None,
)
tracer = FakeTracer()
tracer.new_session()
tracer.on_chain_start(serialized={"name": "chain"}, inputs={}, run_id=uuid)
tracer.on_chain_end(outputs={}, run_id=uuid)
assert tracer.runs == [compare_run]
@freeze_time("2023-01-01")
def test_tracer_tool_run() -> None:
"""Test tracer on a Tool run."""
uuid = uuid4()
compare_run = ToolRun(
uuid=str(uuid),
parent_uuid=None,
start_time=datetime.now(timezone.utc),
end_time=datetime.now(timezone.utc),
extra={},
execution_order=1,
child_execution_order=1,
serialized={"name": "tool"},
tool_input="test",
output="test",
action="{'name': 'tool'}",
session_id=TEST_SESSION_ID,
error=None,
)
tracer = FakeTracer()
tracer.new_session()
tracer.on_tool_start(serialized={"name": "tool"}, input_str="test", run_id=uuid)
tracer.on_tool_end("test", run_id=uuid)
assert tracer.runs == [compare_run]
@freeze_time("2023-01-01")
def test_tracer_nested_run() -> None:
"""Test tracer on a nested run."""
tracer = FakeTracer()
tracer.new_session()
chain_uuid = uuid4()
tool_uuid = uuid4()
llm_uuid1 = uuid4()
llm_uuid2 = uuid4()
for _ in range(10):
tracer.on_chain_start(
serialized={"name": "chain"}, inputs={}, run_id=chain_uuid
)
tracer.on_tool_start(
serialized={"name": "tool"},
input_str="test",
run_id=tool_uuid,
parent_run_id=chain_uuid,
)
tracer.on_llm_start(
serialized=SERIALIZED,
prompts=[],
run_id=llm_uuid1,
parent_run_id=tool_uuid,
)
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1)
tracer.on_tool_end("test", run_id=tool_uuid)
tracer.on_llm_start(
serialized=SERIALIZED,
prompts=[],
run_id=llm_uuid2,
parent_run_id=chain_uuid,
)
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid2)
tracer.on_chain_end(outputs={}, run_id=chain_uuid)
compare_run = ChainRun(
uuid=str(chain_uuid),
error=None,
start_time=datetime.now(timezone.utc),
end_time=datetime.now(timezone.utc),
extra={},
execution_order=1,
child_execution_order=4,
serialized={"name": "chain"},
inputs={},
outputs={},
session_id=TEST_SESSION_ID,
child_chain_runs=[],
child_tool_runs=[
ToolRun(
uuid=str(tool_uuid),
parent_uuid=str(chain_uuid),
start_time=datetime.now(timezone.utc),
end_time=datetime.now(timezone.utc),
extra={},
execution_order=2,
child_execution_order=3,
serialized={"name": "tool"},
tool_input="test",
output="test",
action="{'name': 'tool'}",
session_id=TEST_SESSION_ID,
error=None,
child_chain_runs=[],
child_tool_runs=[],
child_llm_runs=[
LLMRun(
uuid=str(llm_uuid1),
parent_uuid=str(tool_uuid),
error=None,
start_time=datetime.now(timezone.utc),
end_time=datetime.now(timezone.utc),
extra={},
execution_order=3,
child_execution_order=3,
serialized=SERIALIZED,
prompts=[],
response=LLMResult(generations=[[]]),
session_id=TEST_SESSION_ID,
)
],
),
],
child_llm_runs=[
LLMRun(
uuid=str(llm_uuid2),
parent_uuid=str(chain_uuid),
error=None,
start_time=datetime.now(timezone.utc),
end_time=datetime.now(timezone.utc),
extra={},
execution_order=4,
child_execution_order=4,
serialized=SERIALIZED,
prompts=[],
response=LLMResult(generations=[[]]),
session_id=TEST_SESSION_ID,
),
],
)
assert tracer.runs[0] == compare_run
assert tracer.runs == [compare_run] * 10
@freeze_time("2023-01-01")
def test_tracer_llm_run_on_error() -> None:
"""Test tracer on an LLM run with an error."""
exception = Exception("test")
uuid = uuid4()
compare_run = LLMRun(
uuid=str(uuid),
parent_uuid=None,
start_time=datetime.now(timezone.utc),
end_time=datetime.now(timezone.utc),
extra={},
execution_order=1,
child_execution_order=1,
serialized=SERIALIZED,
prompts=[],
response=None,
session_id=TEST_SESSION_ID,
error=repr(exception),
)
tracer = FakeTracer()
tracer.new_session()
tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
tracer.on_llm_error(exception, run_id=uuid)
_compare_run_with_error(tracer.runs[0], compare_run)
@freeze_time("2023-01-01")
def test_tracer_chain_run_on_error() -> None:
"""Test tracer on a Chain run with an error."""
exception = Exception("test")
uuid = uuid4()
compare_run = ChainRun(
uuid=str(uuid),
parent_uuid=None,
start_time=datetime.now(timezone.utc),
end_time=datetime.now(timezone.utc),
extra={},
execution_order=1,
child_execution_order=1,
serialized={"name": "chain"},
inputs={},
outputs=None,
session_id=TEST_SESSION_ID,
error=repr(exception),
)
tracer = FakeTracer()
tracer.new_session()
tracer.on_chain_start(serialized={"name": "chain"}, inputs={}, run_id=uuid)
tracer.on_chain_error(exception, run_id=uuid)
_compare_run_with_error(tracer.runs[0], compare_run)
@freeze_time("2023-01-01")
def test_tracer_tool_run_on_error() -> None:
"""Test tracer on a Tool run with an error."""
exception = Exception("test")
uuid = uuid4()
compare_run = ToolRun(
uuid=str(uuid),
parent_uuid=None,
start_time=datetime.now(timezone.utc),
end_time=datetime.now(timezone.utc),
extra={},
execution_order=1,
child_execution_order=1,
serialized={"name": "tool"},
tool_input="test",
output=None,
action="{'name': 'tool'}",
session_id=TEST_SESSION_ID,
error=repr(exception),
)
tracer = FakeTracer()
tracer.new_session()
tracer.on_tool_start(serialized={"name": "tool"}, input_str="test", run_id=uuid)
tracer.on_tool_error(exception, run_id=uuid)
_compare_run_with_error(tracer.runs[0], compare_run)
@pytest.fixture
def sample_tracer_session_v1() -> TracerSessionV1:
return TracerSessionV1(id=2, name="Sample session")
@freeze_time("2023-01-01")
def test_convert_run(
lang_chain_tracer_v1: LangChainTracerV1,
sample_tracer_session_v1: TracerSessionV1,
) -> None:
"""Test converting a run to a V1 run."""
llm_run = Run(
id="57a08cc4-73d2-4236-8370-549099d07fad",
name="llm_run",
execution_order=1,
child_execution_order=1,
start_time=datetime.now(timezone.utc),
end_time=datetime.now(timezone.utc),
session_id=TEST_SESSION_ID,
inputs={"prompts": []},
outputs=LLMResult(generations=[[]]).dict(),
serialized={},
extra={},
run_type="llm",
)
chain_run = Run(
id="57a08cc4-73d2-4236-8371-549099d07fad",
name="chain_run",
execution_order=1,
start_time=datetime.now(timezone.utc),
end_time=datetime.now(timezone.utc),
child_execution_order=1,
serialized={},
inputs={},
outputs={},
child_runs=[llm_run],
extra={},
run_type="chain",
)
tool_run = Run(
id="57a08cc4-73d2-4236-8372-549099d07fad",
name="tool_run",
execution_order=1,
child_execution_order=1,
inputs={"input": "test"},
start_time=datetime.now(timezone.utc),
end_time=datetime.now(timezone.utc),
outputs=None,
serialized={},
child_runs=[],
extra={},
run_type="tool",
)
expected_llm_run = LLMRun(
uuid="57a08cc4-73d2-4236-8370-549099d07fad",
name="llm_run",
execution_order=1,
child_execution_order=1,
start_time=datetime.now(timezone.utc),
end_time=datetime.now(timezone.utc),
session_id=2,
prompts=[],
response=LLMResult(generations=[[]]),
serialized={},
extra={},
)
expected_chain_run = ChainRun(
uuid="57a08cc4-73d2-4236-8371-549099d07fad",
name="chain_run",
execution_order=1,
child_execution_order=1,
start_time=datetime.now(timezone.utc),
end_time=datetime.now(timezone.utc),
session_id=2,
serialized={},
inputs={},
outputs={},
child_llm_runs=[expected_llm_run],
child_chain_runs=[],
child_tool_runs=[],
extra={},
)
expected_tool_run = ToolRun(
uuid="57a08cc4-73d2-4236-8372-549099d07fad",
name="tool_run",
execution_order=1,
child_execution_order=1,
session_id=2,
start_time=datetime.now(timezone.utc),
end_time=datetime.now(timezone.utc),
tool_input="test",
action="{}",
serialized={},
child_llm_runs=[],
child_chain_runs=[],
child_tool_runs=[],
extra={},
)
lang_chain_tracer_v1.session = sample_tracer_session_v1
converted_llm_run = lang_chain_tracer_v1._convert_to_v1_run(llm_run)
converted_chain_run = lang_chain_tracer_v1._convert_to_v1_run(chain_run)
converted_tool_run = lang_chain_tracer_v1._convert_to_v1_run(tool_run)
assert isinstance(converted_llm_run, LLMRun)
assert isinstance(converted_chain_run, ChainRun)
assert isinstance(converted_tool_run, ToolRun)
assert converted_llm_run == expected_llm_run
assert converted_tool_run == expected_tool_run
assert converted_chain_run == expected_chain_run

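One conversion detail worth calling out from test_tracer_chat_model_run above: the V1 LLMRun schema has no notion of chat messages, so _convert_to_v1_run flattens them with get_buffer_string. A quick check of that helper (a real function in langchain_core.messages):

from langchain_core.messages import HumanMessage, get_buffer_string

# An empty HumanMessage renders as "Human: ", which is exactly the
# prompts=["Human: "] value the chat-model test expects.
assert get_buffer_string([HumanMessage(content="")]) == "Human: "
assert get_buffer_string([HumanMessage(content="hi")]) == "Human: hi"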
@@ -5,7 +5,17 @@ from langchain_core.tracers.schemas import __all__ as schemas_all
def test_public_api() -> None:
"""Test for changes in the public API."""
expected_all = [
"BaseRun",
"ChainRun",
"LLMRun",
"Run",
"RunTypeEnum",
"ToolRun",
"TracerSession",
"TracerSessionBase",
"TracerSessionV1",
"TracerSessionV1Base",
"TracerSessionV1Create",
]
assert sorted(schemas_all) == expected_all

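RunTypeEnum sits in the expected public API above but is only a deprecation shim (see the schemas diff earlier): calling it warns and returns langsmith's enum class. A sketch of that behavior:

import warnings
from langchain_core.tracers.schemas import RunTypeEnum

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    run_type_enum = RunTypeEnum()  # returns langsmith.schemas.RunTypeEnum
assert any(issubclass(w.category, DeprecationWarning) for w in caught)
assert run_type_enum.llm.value == "llm"  # members come from langsmith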
@@ -44,6 +44,7 @@ from langchain_core.callbacks import (
)
from langchain_core.tracers.context import (
collect_runs,
tracing_enabled,
tracing_v2_enabled,
)
from langchain_core.tracers.langchain import LangChainTracer
@@ -79,6 +80,7 @@ __all__ = [
"WandbCallbackHandler",
"WhyLabsCallbackHandler",
"get_openai_callback",
"tracing_enabled",
"tracing_v2_enabled",
"collect_runs",
"wandb_tracing_enabled",

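The hunks above and below restore tracing_enabled in downstream __all__ lists. Assuming these are the langchain and langchain_community callbacks packages (this view omits file names), the re-exports resolve to the same object rather than copies:

# Hypothetical check; the langchain.callbacks path is an assumption inferred
# from the surrounding imports, since the diff view omits file names.
from langchain.callbacks import tracing_enabled as shim_fn
from langchain_core.tracers.context import tracing_enabled as core_fn

assert shim_fn is core_fn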
@@ -30,6 +30,7 @@ from langchain_core.callbacks.manager import (
)
from langchain_core.tracers.context import (
collect_runs,
tracing_enabled,
tracing_v2_enabled,
)
from langchain_core.utils.env import env_var_is_set
@@ -52,6 +53,7 @@ __all__ = [
"CallbackManagerForChainGroup",
"AsyncCallbackManager",
"AsyncCallbackManagerForChainGroup",
"tracing_enabled",
"tracing_v2_enabled",
"collect_runs",
"atrace_as_chain_group",

@@ -1,6 +1,7 @@
"""Tracers that record execution of LangChain runs."""
from langchain_core.tracers.langchain import LangChainTracer
from langchain_core.tracers.langchain_v1 import LangChainTracerV1
from langchain_core.tracers.stdout import (
ConsoleCallbackHandler,
FunctionCallbackHandler,
@@ -14,5 +15,6 @@ __all__ = [
"FunctionCallbackHandler",
"LoggingCallbackHandler",
"LangChainTracer",
"LangChainTracerV1",
"WandbTracer",
]

@@ -0,0 +1,3 @@
from langchain_core.tracers.langchain_v1 import LangChainTracerV1
__all__ = ["LangChainTracerV1"]

@@ -1,7 +1,27 @@
from langchain_core.tracers.schemas import (
BaseRun,
ChainRun,
LLMRun,
Run,
RunTypeEnum,
ToolRun,
TracerSession,
TracerSessionBase,
TracerSessionV1,
TracerSessionV1Base,
TracerSessionV1Create,
)
__all__ = [
"BaseRun",
"ChainRun",
"LLMRun",
"Run",
"RunTypeEnum",
"ToolRun",
"TracerSession",
"TracerSessionBase",
"TracerSessionV1",
"TracerSessionV1Base",
"TracerSessionV1Create",
]

@@ -22,11 +22,13 @@ from langchain_core.callbacks.manager import (
from langchain_core.tracers.context import (
collect_runs,
register_configure_hook,
tracing_enabled,
tracing_v2_enabled,
)
from langchain_core.utils.env import env_var_is_set
__all__ = [
"tracing_enabled",
"tracing_v2_enabled",
"collect_runs",
"trace_as_chain_group",

@@ -0,0 +1,3 @@
from langchain_core.tracers.langchain_v1 import LangChainTracerV1, get_headers
__all__ = ["get_headers", "LangChainTracerV1"]

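Exercising the get_headers shim restored just above (the API key is a placeholder; per the V1 tracer's get_headers, x-api-key is only included when LANGCHAIN_API_KEY is set):

import os
from langchain_core.tracers.langchain_v1 import get_headers

os.environ["LANGCHAIN_API_KEY"] = "test-key"  # placeholder value
assert get_headers() == {
    "Content-Type": "application/json",
    "x-api-key": "test-key",
}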
@@ -1,7 +1,27 @@
from langchain_core.tracers.schemas import (
BaseRun,
ChainRun,
LLMRun,
Run,
RunTypeEnum,
ToolRun,
TracerSession,
TracerSessionBase,
TracerSessionV1,
TracerSessionV1Base,
TracerSessionV1Create,
)
__all__ = [
"RunTypeEnum",
"TracerSessionV1Base",
"TracerSessionV1Create",
"TracerSessionV1",
"TracerSessionBase",
"TracerSession",
"BaseRun",
"LLMRun",
"ChainRun",
"ToolRun",
"Run",
]

@@ -25,6 +25,7 @@ EXPECTED_ALL = [
"WandbCallbackHandler",
"WhyLabsCallbackHandler",
"get_openai_callback",
"tracing_enabled",
"tracing_v2_enabled",
"collect_runs",
"wandb_tracing_enabled",

@@ -18,6 +18,7 @@ EXPECTED_ALL = [
"CallbackManagerForChainGroup",
"AsyncCallbackManager",
"AsyncCallbackManagerForChainGroup",
"tracing_enabled",
"tracing_v2_enabled",
"collect_runs",
"atrace_as_chain_group",
