mirror of
https://github.com/hwchase17/langchain
synced 2024-10-29 17:07:25 +00:00
18af149e91
Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
468 lines
14 KiB
Python
468 lines
14 KiB
Python
"""Test Tracer classes."""
|
|
from __future__ import annotations
|
|
|
|
from datetime import datetime
|
|
from typing import List
|
|
from uuid import uuid4
|
|
|
|
import pytest
|
|
from freezegun import freeze_time
|
|
|
|
from langchain.callbacks.manager import CallbackManager
|
|
from langchain.callbacks.tracers.base import BaseTracer, TracerException
|
|
from langchain.callbacks.tracers.schemas import Run
|
|
from langchain.schema import LLMResult
|
|
|
|
# Minimal serialized payloads for an LLM and a chat model. They are passed to
# the tracer's start callbacks and echoed back in the expected ``Run.serialized``.
SERIALIZED = dict(id=["llm"])
SERIALIZED_CHAT = dict(id=["chat_model"])
|
|
|
|
|
|
class FakeTracer(BaseTracer):
    """In-memory tracer used by these tests.

    Every run handed to ``_persist_run`` is kept in ``runs`` so a test can
    compare what the tracer recorded against hand-built ``Run`` objects.
    """

    def __init__(self) -> None:
        """Initialise the base tracer and start with an empty run log."""
        super().__init__()
        # Runs in the order they were persisted.
        self.runs: List[Run] = []

    def _persist_run(self, run: Run) -> None:
        """Record ``run`` in memory instead of sending it anywhere."""
        self.runs += [run]
|
|
|
|
|
|
@freeze_time("2023-01-01")
def test_tracer_llm_run() -> None:
    """A single LLM start/end pair is persisted as one root ``llm`` run."""
    run_id = uuid4()
    # Time is frozen, so one reading serves as both start and end time.
    frozen_now = datetime.utcnow()
    expected = Run(
        id=run_id,
        parent_run_id=None,
        start_time=frozen_now,
        end_time=frozen_now,
        extra={},
        execution_order=1,
        child_execution_order=1,
        serialized=SERIALIZED,
        inputs=dict(prompts=[]),
        outputs=LLMResult(generations=[[]]),
        error=None,
        run_type="llm",
    )
    tracer = FakeTracer()

    tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=run_id)
    tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=run_id)

    assert tracer.runs == [expected]
|
|
|
|
|
|
@freeze_time("2023-01-01")
def test_tracer_chat_model_run() -> None:
    """A chat model run routed through ``CallbackManager`` is traced as ``llm``.

    One empty message list is sent; the expected run records
    ``prompts=[""]``, the name ``chat_model`` and ``run_type="llm"``.
    """
    run_id = uuid4()
    # Time is frozen, so one reading serves as both start and end time.
    frozen_now = datetime.utcnow()
    expected = Run(
        id=str(run_id),
        name="chat_model",
        start_time=frozen_now,
        end_time=frozen_now,
        extra={},
        execution_order=1,
        child_execution_order=1,
        serialized=SERIALIZED_CHAT,
        inputs={"prompts": [""]},
        outputs=LLMResult(generations=[[]]),
        error=None,
        run_type="llm",
    )
    tracer = FakeTracer()
    callback_manager = CallbackManager(handlers=[tracer])

    llm_run_manager = callback_manager.on_chat_model_start(
        serialized=SERIALIZED_CHAT, messages=[[]], run_id=run_id
    )
    llm_run_manager.on_llm_end(response=LLMResult(generations=[[]]))

    assert tracer.runs == [expected]
|
|
|
|
|
|
@freeze_time("2023-01-01")
def test_tracer_llm_run_errors_no_start() -> None:
    """Ending an LLM run whose start was never seen raises ``TracerException``."""
    tracer = FakeTracer()
    never_started_id = uuid4()
    empty_result = LLMResult(generations=[[]])

    with pytest.raises(TracerException):
        tracer.on_llm_end(response=empty_result, run_id=never_started_id)
|
|
|
|
|
|
@freeze_time("2023-01-01")
def test_tracer_multiple_llm_runs() -> None:
    """Sequential LLM runs reusing one run id are recorded identically.

    Every start/end cycle persists a root run with ``execution_order`` 1,
    so the log is the same expected run repeated once per cycle.
    """
    run_id = uuid4()
    # Time is frozen, so one reading serves as both start and end time.
    frozen_now = datetime.utcnow()
    expected = Run(
        id=run_id,
        name="llm",
        start_time=frozen_now,
        end_time=frozen_now,
        extra={},
        execution_order=1,
        child_execution_order=1,
        serialized=SERIALIZED,
        inputs={"prompts": []},
        outputs=LLMResult(generations=[[]]),
        error=None,
        run_type="llm",
    )
    tracer = FakeTracer()

    repeats = 10
    for _ in range(repeats):
        tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=run_id)
        tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=run_id)

    assert tracer.runs == [expected for _ in range(repeats)]
|
|
|
|
|
|
@freeze_time("2023-01-01")
def test_tracer_chain_run() -> None:
    """A chain start/end pair with empty inputs/outputs yields one chain run."""
    run_id = uuid4()
    # Time is frozen, so one reading serves as both start and end time.
    frozen_now = datetime.utcnow()
    expected = Run(
        id=str(run_id),
        start_time=frozen_now,
        end_time=frozen_now,
        extra={},
        execution_order=1,
        child_execution_order=1,
        serialized=dict(name="chain"),
        inputs=dict(),
        outputs=dict(),
        error=None,
        run_type="chain",
    )
    tracer = FakeTracer()

    tracer.on_chain_start(
        serialized={"name": "chain"}, inputs={}, run_id=run_id
    )
    tracer.on_chain_end(outputs={}, run_id=run_id)

    assert tracer.runs == [expected]
|
|
|
|
|
|
@freeze_time("2023-01-01")
def test_tracer_tool_run() -> None:
    """A tool start/end pair records its strings under ``input``/``output``."""
    run_id = uuid4()
    # Time is frozen, so one reading serves as both start and end time.
    frozen_now = datetime.utcnow()
    expected = Run(
        id=str(run_id),
        start_time=frozen_now,
        end_time=frozen_now,
        extra={},
        execution_order=1,
        child_execution_order=1,
        serialized=dict(name="tool"),
        inputs=dict(input="test"),
        outputs=dict(output="test"),
        error=None,
        run_type="tool",
    )
    tracer = FakeTracer()

    tracer.on_tool_start(
        serialized={"name": "tool"}, input_str="test", run_id=run_id
    )
    tracer.on_tool_end("test", run_id=run_id)

    assert tracer.runs == [expected]
|
|
|
|
|
|
@freeze_time("2023-01-01")
def test_tracer_nested_run() -> None:
    """A chain with a tool child (which calls an LLM) plus a direct LLM child.

    The same run ids are replayed ten times. Each replay must persist one
    root chain run carrying the full child tree. Execution order numbers the
    runs 1-4 in start order.
    """
    tracer = FakeTracer()

    chain_id = uuid4()
    tool_id = uuid4()
    tool_llm_id = uuid4()
    chain_llm_id = uuid4()

    replays = 10
    for _ in range(replays):
        tracer.on_chain_start(
            serialized={"name": "chain"}, inputs={}, run_id=chain_id
        )
        tracer.on_tool_start(
            serialized={"name": "tool"},
            input_str="test",
            run_id=tool_id,
            parent_run_id=chain_id,
        )
        tracer.on_llm_start(
            serialized=SERIALIZED,
            prompts=[],
            run_id=tool_llm_id,
            parent_run_id=tool_id,
        )
        tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=tool_llm_id)
        tracer.on_tool_end("test", run_id=tool_id)
        tracer.on_llm_start(
            serialized=SERIALIZED,
            prompts=[],
            run_id=chain_llm_id,
            parent_run_id=chain_id,
        )
        tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=chain_llm_id)
        tracer.on_chain_end(outputs={}, run_id=chain_id)

    # Time is frozen, so one reading serves for every start and end time.
    frozen_now = datetime.utcnow()

    # LLM called from inside the tool: third run to start.
    tool_llm_run = Run(
        id=str(tool_llm_id),
        parent_run_id=str(tool_id),
        error=None,
        start_time=frozen_now,
        end_time=frozen_now,
        extra={},
        execution_order=3,
        child_execution_order=3,
        serialized=SERIALIZED,
        inputs={"prompts": []},
        outputs=LLMResult(generations=[[]]),
        run_type="llm",
    )
    # Tool: second run; its highest descendant is the LLM above (3).
    tool_run = Run(
        id=tool_id,
        parent_run_id=chain_id,
        start_time=frozen_now,
        end_time=frozen_now,
        extra={},
        execution_order=2,
        child_execution_order=3,
        serialized={"name": "tool"},
        inputs={"input": "test"},
        outputs={"output": "test"},
        error=None,
        run_type="tool",
        child_runs=[tool_llm_run],
    )
    # LLM called directly by the chain after the tool finished: fourth run.
    chain_llm_run = Run(
        id=str(chain_llm_id),
        parent_run_id=str(chain_id),
        error=None,
        start_time=frozen_now,
        end_time=frozen_now,
        extra={},
        execution_order=4,
        child_execution_order=4,
        serialized=SERIALIZED,
        inputs={"prompts": []},
        outputs=LLMResult(generations=[[]]),
        run_type="llm",
    )
    expected = Run(
        id=str(chain_id),
        error=None,
        start_time=frozen_now,
        end_time=frozen_now,
        extra={},
        execution_order=1,
        child_execution_order=4,
        serialized={"name": "chain"},
        inputs={},
        outputs={},
        run_type="chain",
        child_runs=[tool_run, chain_llm_run],
    )

    # Check the first run alone for a readable diff before the full log.
    assert tracer.runs[0] == expected
    assert tracer.runs == [expected] * replays
|
|
|
|
|
|
@freeze_time("2023-01-01")
def test_tracer_llm_run_on_error() -> None:
    """An LLM run ended by ``on_llm_error`` stores ``repr`` of the error, no outputs."""
    failure = Exception("test")
    run_id = uuid4()
    # Time is frozen, so one reading serves as both start and end time.
    frozen_now = datetime.utcnow()

    expected = Run(
        id=str(run_id),
        start_time=frozen_now,
        end_time=frozen_now,
        extra={},
        execution_order=1,
        child_execution_order=1,
        serialized=SERIALIZED,
        inputs={"prompts": []},
        outputs=None,
        error=repr(failure),
        run_type="llm",
    )
    tracer = FakeTracer()

    tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=run_id)
    tracer.on_llm_error(failure, run_id=run_id)

    assert tracer.runs == [expected]
|
|
|
|
|
|
@freeze_time("2023-01-01")
def test_tracer_chain_run_on_error() -> None:
    """A chain ended by ``on_chain_error`` stores ``repr`` of the error, no outputs."""
    failure = Exception("test")
    run_id = uuid4()
    # Time is frozen, so one reading serves as both start and end time.
    frozen_now = datetime.utcnow()

    expected = Run(
        id=str(run_id),
        start_time=frozen_now,
        end_time=frozen_now,
        extra={},
        execution_order=1,
        child_execution_order=1,
        serialized=dict(name="chain"),
        inputs=dict(),
        outputs=None,
        error=repr(failure),
        run_type="chain",
    )
    tracer = FakeTracer()

    tracer.on_chain_start(
        serialized={"name": "chain"}, inputs={}, run_id=run_id
    )
    tracer.on_chain_error(failure, run_id=run_id)

    assert tracer.runs == [expected]
|
|
|
|
|
|
@freeze_time("2023-01-01")
def test_tracer_tool_run_on_error() -> None:
    """A tool run ended by ``on_tool_error`` stores the error and no outputs.

    The expected run has the same fields as a successful tool run, except
    that ``outputs`` is ``None`` and ``error`` holds ``repr`` of the exception.
    """
    exception = Exception("test")
    uuid = uuid4()
    # Time is frozen, so one reading serves as both start and end time.
    now = datetime.utcnow()

    compare_run = Run(
        id=uuid,
        start_time=now,
        end_time=now,
        extra={},
        execution_order=1,
        child_execution_order=1,
        serialized={"name": "tool"},
        inputs={"input": "test"},
        outputs=None,
        error=repr(exception),
        run_type="tool",
    )
    tracer = FakeTracer()

    tracer.on_tool_start(
        serialized={"name": "tool"}, input_str="test", run_id=uuid
    )
    tracer.on_tool_error(exception, run_id=uuid)

    assert tracer.runs == [compare_run]
|
|
|
|
|
|
@freeze_time("2023-01-01")
def test_tracer_nested_runs_on_error() -> None:
    """Nested runs that end in errors are all persisted with their errors.

    Each iteration runs a chain with two successful child LLM runs and then
    a tool whose child LLM fails. The LLM, tool and chain are each ended with
    the same exception. Execution order numbers the runs 1-5 in start order.
    Each parent's ``child_execution_order`` is the highest number among its
    descendants.
    """
    exception = Exception("test")

    tracer = FakeTracer()
    chain_uuid = uuid4()
    tool_uuid = uuid4()
    llm_uuid1 = uuid4()
    llm_uuid2 = uuid4()
    llm_uuid3 = uuid4()

    num_runs = 3
    for _ in range(num_runs):
        tracer.on_chain_start(
            serialized={"name": "chain"}, inputs={}, run_id=chain_uuid
        )
        tracer.on_llm_start(
            serialized=SERIALIZED,
            prompts=[],
            run_id=llm_uuid1,
            parent_run_id=chain_uuid,
        )
        tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1)
        tracer.on_llm_start(
            serialized=SERIALIZED,
            prompts=[],
            run_id=llm_uuid2,
            parent_run_id=chain_uuid,
        )
        tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid2)
        tracer.on_tool_start(
            serialized={"name": "tool"},
            input_str="test",
            run_id=tool_uuid,
            parent_run_id=chain_uuid,
        )
        tracer.on_llm_start(
            serialized=SERIALIZED,
            prompts=[],
            run_id=llm_uuid3,
            parent_run_id=tool_uuid,
        )
        # Fail innermost-first: the LLM, then its tool, then the chain.
        tracer.on_llm_error(exception, run_id=llm_uuid3)
        tracer.on_tool_error(exception, run_id=tool_uuid)
        tracer.on_chain_error(exception, run_id=chain_uuid)

    # Time is frozen, so one reading serves for every start and end time.
    now = datetime.utcnow()

    # Two successful LLM calls made directly by the chain (runs 2 and 3).
    llm_run1 = Run(
        id=llm_uuid1,
        parent_run_id=chain_uuid,
        start_time=now,
        end_time=now,
        extra={},
        execution_order=2,
        child_execution_order=2,
        serialized=SERIALIZED,
        error=None,
        inputs={"prompts": []},
        outputs=LLMResult(generations=[[]], llm_output=None),
        run_type="llm",
    )
    llm_run2 = Run(
        id=llm_uuid2,
        parent_run_id=chain_uuid,
        start_time=now,
        end_time=now,
        extra={},
        execution_order=3,
        child_execution_order=3,
        serialized=SERIALIZED,
        error=None,
        inputs={"prompts": []},
        outputs=LLMResult(generations=[[]], llm_output=None),
        run_type="llm",
    )
    # The failing LLM inside the tool (run 5).
    llm_run3 = Run(
        id=llm_uuid3,
        parent_run_id=tool_uuid,
        start_time=now,
        end_time=now,
        extra={},
        execution_order=5,
        child_execution_order=5,
        serialized=SERIALIZED,
        error=repr(exception),
        inputs={"prompts": []},
        outputs=None,
        run_type="llm",
    )
    # The failing tool (run 4); its highest descendant is run 5.
    tool_run = Run(
        id=tool_uuid,
        parent_run_id=chain_uuid,
        start_time=now,
        end_time=now,
        extra={},
        execution_order=4,
        child_execution_order=5,
        serialized={"name": "tool"},
        error=repr(exception),
        inputs={"input": "test"},
        outputs=None,
        run_type="tool",
        child_runs=[llm_run3],
    )
    compare_run = Run(
        id=chain_uuid,
        start_time=now,
        end_time=now,
        extra={},
        execution_order=1,
        child_execution_order=5,
        serialized={"name": "chain"},
        error=repr(exception),
        inputs={},
        outputs=None,
        run_type="chain",
        child_runs=[llm_run1, llm_run2, tool_run],
    )

    # Check the first run alone for a readable diff before the full log.
    assert tracer.runs[0] == compare_run
    assert tracer.runs == [compare_run] * num_runs
|