mirror of
https://github.com/hwchase17/langchain
synced 2024-11-10 01:10:59 +00:00
57609845df
* add implementations of `BaseCallbackHandler` to support tracing: `SharedTracer` which is thread-safe and `Tracer` which is not and is meant to be used locally. * Tracers persist runs to locally running `langchain-server` Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
531 lines
15 KiB
Python
531 lines
15 KiB
Python
"""Test Tracer classes."""
|
|
from __future__ import annotations
|
|
|
|
import threading
|
|
from datetime import datetime
|
|
from typing import List, Optional, Union
|
|
|
|
import pytest
|
|
from freezegun import freeze_time
|
|
|
|
from langchain.callbacks.tracers.base import (
|
|
BaseTracer,
|
|
ChainRun,
|
|
LLMRun,
|
|
SharedTracer,
|
|
ToolRun,
|
|
Tracer,
|
|
TracerException,
|
|
TracerSession,
|
|
)
|
|
from langchain.callbacks.tracers.schemas import TracerSessionCreate
|
|
from langchain.schema import AgentAction, LLMResult
|
|
|
|
# Session id assigned by the fake ``_persist_session`` helpers in this module;
# every expected run built by these tests carries it as ``session_id``.
TEST_SESSION_ID = 2023
|
|
|
|
|
|
@freeze_time("2023-01-01")
def _get_compare_run() -> Union[LLMRun, ChainRun, ToolRun]:
    """Build the run tree that ``_perform_nested_run`` should leave behind.

    Shape: a chain run (order 1) whose children are a tool run (order 2)
    wrapping one LLM run (order 3), followed by a second LLM run (order 4).
    Time is frozen at 2023-01-01, so every timestamp is the same instant.
    """
    now = datetime.utcnow()

    def llm_run(order: int) -> LLMRun:
        # A successful LLM call with no prompts and an empty generation list.
        return LLMRun(
            id=None,
            error=None,
            start_time=now,
            end_time=now,
            extra={},
            execution_order=order,
            serialized={},
            prompts=[],
            response=LLMResult([[]]),
            session_id=TEST_SESSION_ID,
        )

    tool_run = ToolRun(
        id=None,
        start_time=now,
        end_time=now,
        extra={},
        execution_order=2,
        serialized={},
        tool_input="test",
        output="test",
        action="action",
        session_id=TEST_SESSION_ID,
        error=None,
        child_runs=[llm_run(3)],
    )

    return ChainRun(
        id=None,
        error=None,
        start_time=now,
        end_time=now,
        extra={},
        execution_order=1,
        serialized={},
        inputs={},
        outputs={},
        session_id=TEST_SESSION_ID,
        child_runs=[tool_run, llm_run(4)],
    )
|
|
|
|
|
|
def _perform_nested_run(tracer: BaseTracer) -> None:
    """Drive ``tracer`` through one chain -> tool -> LLM callback sequence.

    Order of events: chain start, tool start, an LLM start/end inside the
    tool, tool end, a second LLM start/end directly under the chain, and
    finally chain end.
    """

    def run_llm() -> None:
        # One complete LLM call: no prompts, empty generations.
        tracer.on_llm_start(serialized={}, prompts=[])
        tracer.on_llm_end(response=LLMResult([[]]))

    tracer.on_chain_start(serialized={}, inputs={})
    action = AgentAction(tool="action", tool_input="test", log="")
    tracer.on_tool_start(serialized={}, action=action)
    run_llm()  # issued while the tool run is still open
    tracer.on_tool_end("test")
    run_llm()  # issued after the tool run closed, still inside the chain
    tracer.on_chain_end(outputs={})
|
|
|
|
|
|
def _add_child_run(
|
|
parent_run: Union[ChainRun, ToolRun],
|
|
child_run: Union[LLMRun, ChainRun, ToolRun],
|
|
) -> None:
|
|
"""Add child run to a chain run or tool run."""
|
|
parent_run.child_runs.append(child_run)
|
|
|
|
|
|
def _generate_id() -> Optional[Union[int, str]]:
|
|
"""Generate an id for a run."""
|
|
return None
|
|
|
|
|
|
def load_session(session_name: str) -> TracerSession:
    """Return an in-memory session called ``session_name`` with id 1."""
    started = datetime.utcnow()
    return TracerSession(id=1, name=session_name, start_time=started)
|
|
|
|
|
|
def _persist_session(session: TracerSessionCreate) -> TracerSession:
    """Pretend to store ``session``: copy its fields and assign TEST_SESSION_ID."""
    fields = session.dict()
    return TracerSession(id=TEST_SESSION_ID, **fields)
|
|
|
|
|
|
def load_default_session() -> TracerSession:
    """Return the in-memory session named ``"default"`` (id 1)."""
    return load_session("default")
|
|
|
|
|
|
class FakeTracer(Tracer):
    """``Tracer`` fake that keeps persisted runs in the ``runs`` list.

    Id generation, child attachment and session handling all delegate to the
    module-level fake helpers of the same name.
    """

    def __init__(self) -> None:
        """Initialise the base tracer, then start with an empty run log."""
        super().__init__()
        self.runs: List[Union[LLMRun, ChainRun, ToolRun]] = list()

    def load_session(self, session_name: str) -> TracerSession:
        """Return the fake session named ``session_name``."""
        return load_session(session_name)

    def load_default_session(self) -> TracerSession:
        """Return the fake ``"default"`` session."""
        return load_default_session()

    def _persist_session(self, session: TracerSessionCreate) -> TracerSession:
        """Convert ``session`` into a stored session via the module-level fake."""
        return _persist_session(session)

    def _generate_id(self) -> Optional[Union[int, str]]:
        """Return the id from the module-level fake generator."""
        return _generate_id()

    def _add_child_run(
        self,
        parent_run: Union[ChainRun, ToolRun],
        child_run: Union[LLMRun, ChainRun, ToolRun],
    ) -> None:
        """Attach ``child_run`` beneath ``parent_run``."""
        _add_child_run(parent_run, child_run)

    def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:
        """Record ``run`` in ``self.runs``."""
        self.runs.append(run)
|
|
|
|
|
|
class FakeSharedTracer(SharedTracer):
    """``SharedTracer`` fake that keeps persisted runs in ``runs``.

    Mutations of the run log happen while holding ``self._lock`` (supplied by
    the base class); session and id hooks delegate to the module-level fakes.
    """

    # Declared at class level: until ``remove_runs`` rebinds ``runs`` on an
    # instance, appends go to this single list shared by all instances.
    runs: List[Union[LLMRun, ChainRun, ToolRun]] = []

    def load_session(self, session_name: str) -> TracerSession:
        """Return the fake session named ``session_name``."""
        return load_session(session_name)

    def load_default_session(self) -> TracerSession:
        """Return the fake ``"default"`` session."""
        return load_default_session()

    def _persist_session(self, session: TracerSessionCreate) -> TracerSession:
        """Convert ``session`` into a stored session via the module-level fake."""
        return _persist_session(session)

    def _generate_id(self) -> Optional[Union[int, str]]:
        """Return the id from the module-level fake generator."""
        return _generate_id()

    def _add_child_run(
        self,
        parent_run: Union[ChainRun, ToolRun],
        child_run: Union[LLMRun, ChainRun, ToolRun],
    ) -> None:
        """Attach ``child_run`` beneath ``parent_run``."""
        _add_child_run(parent_run, child_run)

    def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:
        """Append ``run`` to the run log while holding the lock."""
        with self._lock:
            self.runs.append(run)

    def remove_runs(self) -> None:
        """Replace the run log with a fresh empty list while holding the lock."""
        with self._lock:
            self.runs = []
|
|
|
|
|
|
@freeze_time("2023-01-01")
def test_tracer_llm_run() -> None:
    """A single successful LLM call is recorded as exactly one LLMRun."""
    tracer = FakeTracer()
    tracer.new_session()

    tracer.on_llm_start(serialized={}, prompts=[])
    tracer.on_llm_end(response=LLMResult([[]]))

    now = datetime.utcnow()
    expected = LLMRun(
        id=None,
        start_time=now,
        end_time=now,
        extra={},
        execution_order=1,
        serialized={},
        prompts=[],
        response=LLMResult([[]]),
        session_id=TEST_SESSION_ID,
        error=None,
    )
    assert tracer.runs == [expected]
|
|
|
|
|
|
@freeze_time("2023-01-01")
def test_tracer_llm_run_errors_no_session() -> None:
    """Starting an LLM run before any session exists raises TracerException."""
    tracer = FakeTracer()
    # Deliberately skip tracer.new_session().
    with pytest.raises(TracerException):
        tracer.on_llm_start(serialized={}, prompts=[])
|
|
|
|
|
|
@freeze_time("2023-01-01")
def test_tracer_llm_run_errors_no_start() -> None:
    """Ending an LLM run that was never started raises TracerException."""
    tracer = FakeTracer()
    tracer.new_session()
    # No matching on_llm_start precedes this end event.
    with pytest.raises(TracerException):
        tracer.on_llm_end(response=LLMResult([[]]))
|
|
|
|
|
|
@freeze_time("2023-01-01")
def test_tracer_multiple_llm_runs() -> None:
    """Ten back-to-back LLM calls yield ten identical top-level runs."""
    num_runs = 10
    tracer = FakeTracer()
    tracer.new_session()

    for _ in range(num_runs):
        tracer.on_llm_start(serialized={}, prompts=[])
        tracer.on_llm_end(response=LLMResult([[]]))

    now = datetime.utcnow()
    expected = LLMRun(
        id=None,
        start_time=now,
        end_time=now,
        extra={},
        execution_order=1,
        serialized={},
        prompts=[],
        response=LLMResult([[]]),
        session_id=TEST_SESSION_ID,
        error=None,
    )
    # Every independent run is expected to start over at execution_order 1.
    expected_runs = [expected] * num_runs
    assert tracer.runs == expected_runs
|
|
|
|
|
|
@freeze_time("2023-01-01")
def test_tracer_chain_run() -> None:
    """A chain start/end pair is recorded as exactly one ChainRun."""
    tracer = FakeTracer()
    tracer.new_session()

    tracer.on_chain_start(serialized={}, inputs={})
    tracer.on_chain_end(outputs={})

    now = datetime.utcnow()
    expected = ChainRun(
        id=None,
        start_time=now,
        end_time=now,
        extra={},
        execution_order=1,
        serialized={},
        inputs={},
        outputs={},
        session_id=TEST_SESSION_ID,
        error=None,
    )
    assert tracer.runs == [expected]
|
|
|
|
|
|
@freeze_time("2023-01-01")
def test_tracer_tool_run() -> None:
    """A tool start/end pair is recorded as exactly one ToolRun."""
    tracer = FakeTracer()
    tracer.new_session()

    action = AgentAction(tool="action", tool_input="test", log="")
    tracer.on_tool_start(serialized={}, action=action)
    tracer.on_tool_end("test")

    now = datetime.utcnow()
    expected = ToolRun(
        id=None,
        start_time=now,
        end_time=now,
        extra={},
        execution_order=1,
        serialized={},
        tool_input="test",
        output="test",
        action="action",
        session_id=TEST_SESSION_ID,
        error=None,
    )
    assert tracer.runs == [expected]
|
|
|
|
|
|
@freeze_time("2023-01-01")
def test_tracer_nested_run() -> None:
    """A nested chain/tool/LLM sequence is recorded as a single run tree."""
    tracer = FakeTracer()
    tracer.new_session()
    _perform_nested_run(tracer)

    expected = _get_compare_run()
    assert tracer.runs == [expected]
|
|
|
|
|
|
@freeze_time("2023-01-01")
def test_tracer_llm_run_on_error() -> None:
    """A failed LLM call is recorded with the exception repr and no response."""
    exception = Exception("test")
    tracer = FakeTracer()
    tracer.new_session()

    tracer.on_llm_start(serialized={}, prompts=[])
    tracer.on_llm_error(exception)

    now = datetime.utcnow()
    error_text = repr(exception)
    expected = LLMRun(
        id=None,
        start_time=now,
        end_time=now,
        extra={},
        execution_order=1,
        serialized={},
        prompts=[],
        response=None,
        session_id=TEST_SESSION_ID,
        error=error_text,
    )
    assert tracer.runs == [expected]
|
|
|
|
|
|
@freeze_time("2023-01-01")
def test_tracer_chain_run_on_error() -> None:
    """A failed chain is recorded with the exception repr and no outputs."""
    exception = Exception("test")
    tracer = FakeTracer()
    tracer.new_session()

    tracer.on_chain_start(serialized={}, inputs={})
    tracer.on_chain_error(exception)

    now = datetime.utcnow()
    error_text = repr(exception)
    expected = ChainRun(
        id=None,
        start_time=now,
        end_time=now,
        extra={},
        execution_order=1,
        serialized={},
        inputs={},
        outputs=None,
        session_id=TEST_SESSION_ID,
        error=error_text,
    )
    assert tracer.runs == [expected]
|
|
|
|
|
|
@freeze_time("2023-01-01")
def test_tracer_tool_run_on_error() -> None:
    """A failed tool call is recorded with the exception repr and no output."""
    exception = Exception("test")
    tracer = FakeTracer()
    tracer.new_session()

    action = AgentAction(tool="action", tool_input="test", log="")
    tracer.on_tool_start(serialized={}, action=action)
    tracer.on_tool_error(exception)

    now = datetime.utcnow()
    error_text = repr(exception)
    expected = ToolRun(
        id=None,
        start_time=now,
        end_time=now,
        extra={},
        execution_order=1,
        serialized={},
        tool_input="test",
        output=None,
        action="action",
        session_id=TEST_SESSION_ID,
        error=error_text,
    )
    assert tracer.runs == [expected]
|
|
|
|
|
|
@freeze_time("2023-01-01")
def test_tracer_nested_runs_on_error() -> None:
    """Failures at LLM, tool and chain level are each recorded on their run.

    Three identical rounds are traced. Each round has two successful LLM
    calls under a chain, then a tool whose inner LLM call fails; the error is
    reported on the LLM, the tool and the chain in turn.
    """
    exception = Exception("test")
    error_text = repr(exception)
    now = datetime.utcnow()

    tracer = FakeTracer()
    tracer.new_session()

    rounds = 3
    for _ in range(rounds):
        tracer.on_chain_start(serialized={}, inputs={})
        # Two successful LLM calls directly under the chain.
        for _ in range(2):
            tracer.on_llm_start(serialized={}, prompts=[])
            tracer.on_llm_end(response=LLMResult([[]]))
        action = AgentAction(tool="action", tool_input="test", log="")
        tracer.on_tool_start(serialized={}, action=action)
        tracer.on_llm_start(serialized={}, prompts=[])
        # Fail innermost-first: LLM, then tool, then chain.
        tracer.on_llm_error(exception)
        tracer.on_tool_error(exception)
        tracer.on_chain_error(exception)

    def finished_llm_run(order: int) -> LLMRun:
        # An LLM call that completed normally with an empty generation list.
        return LLMRun(
            id=None,
            start_time=now,
            end_time=now,
            extra={},
            execution_order=order,
            serialized={},
            session_id=TEST_SESSION_ID,
            error=None,
            prompts=[],
            response=LLMResult(generations=[[]], llm_output=None),
        )

    failed_llm_run = LLMRun(
        id=None,
        start_time=now,
        end_time=now,
        extra={},
        execution_order=5,
        serialized={},
        session_id=TEST_SESSION_ID,
        error=error_text,
        prompts=[],
        response=None,
    )
    failed_tool_run = ToolRun(
        id=None,
        start_time=now,
        end_time=now,
        extra={},
        execution_order=4,
        serialized={},
        session_id=TEST_SESSION_ID,
        error=error_text,
        tool_input="test",
        output=None,
        action="action",
        child_runs=[failed_llm_run],
        child_llm_runs=[],
        child_chain_runs=[],
        child_tool_runs=[],
    )
    expected = ChainRun(
        id=None,
        start_time=now,
        end_time=now,
        extra={},
        execution_order=1,
        serialized={},
        session_id=TEST_SESSION_ID,
        error=error_text,
        inputs={},
        outputs=None,
        child_runs=[finished_llm_run(2), finished_llm_run(3), failed_tool_run],
        child_llm_runs=[],
        child_chain_runs=[],
        child_tool_runs=[],
    )

    assert tracer.runs == [expected] * rounds
|
|
|
|
|
|
@freeze_time("2023-01-01")
def test_shared_tracer_nested_run() -> None:
    """The shared tracer records a nested run as a single run tree."""
    tracer = FakeSharedTracer()
    tracer.new_session()
    # ``runs`` is declared at class level, so other tests may already have
    # appended to it; start from an empty log.
    tracer.remove_runs()

    _perform_nested_run(tracer)

    expected = _get_compare_run()
    assert tracer.runs == [expected]
|
|
|
|
|
|
@freeze_time("2023-01-01")
def test_shared_tracer_nested_run_multithreaded() -> None:
    """Test the shared tracer with nested runs issued from several threads.

    Each worker thread performs one full nested run. Afterwards the tracer
    must hold one identical run tree per thread.

    An exception escaping a ``threading.Thread`` target does not propagate to
    the thread that calls ``join()``. Without extra handling, a crashing
    worker would surface here only as a confusing run-count mismatch. Worker
    exceptions are therefore collected and the first one is re-raised.
    """
    tracer = FakeSharedTracer()
    tracer.remove_runs()
    tracer.new_session()

    num_threads = 10
    errors: List[Exception] = []

    def worker() -> None:
        # Record failures instead of letting the thread swallow them.
        try:
            _perform_nested_run(tracer)
        except Exception as exc:  # re-raised in the main thread below
            errors.append(exc)

    threads = [threading.Thread(target=worker) for _ in range(num_threads)]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()

    if errors:
        # The exception keeps its original traceback from the worker thread.
        raise errors[0]
    assert tracer.runs == [_get_compare_run()] * num_threads
|