mirror of
https://github.com/hwchase17/langchain
synced 2024-10-29 17:07:25 +00:00
928cdd57a4
### Refactor the BaseTracer
- Remove the 'session' abstraction from the BaseTracer
- Rename 'RunV2' objects to 'Run' objects (and rename the previous Run objects to RunV1 objects)
- Ditto for sessions: TracerSession*V2 -> TracerSession*
- Remove the now-deprecated conversion from v1 run objects to v2 run objects in LangChainTracerV2
- Add conversion from v2 run objects to v1 run objects in the V1 tracer
135 lines
4.5 KiB
Python
135 lines
4.5 KiB
Python
"""Test Tracer classes."""
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
from datetime import datetime
|
|
from typing import Tuple
|
|
from unittest.mock import patch
|
|
from uuid import UUID, uuid4
|
|
|
|
import pytest
|
|
from freezegun import freeze_time
|
|
|
|
from langchain.callbacks.tracers.langchain import LangChainTracer
|
|
from langchain.callbacks.tracers.schemas import Run, RunTypeEnum, TracerSession
|
|
from langchain.schema import LLMResult
|
|
|
|
# Fixed identifiers used to build the sample TracerSession, so the session
# fixture yields the same ids on every run.
_SESSION_ID = UUID(hex="4fbf7c55-2727-4711-8964-d821ed4d4e2a")
_TENANT_ID = UUID(hex="57a08cc4-73d2-4236-8378-549099d07fad")
|
|
|
|
|
|
@pytest.fixture
def lang_chain_tracer_v2(monkeypatch: pytest.MonkeyPatch) -> LangChainTracer:
    """Provide a LangChainTracer created under placeholder LangChain settings.

    The tenant id, endpoint and API key environment variables are set through
    ``monkeypatch`` before the tracer is constructed.
    """
    placeholder_env = (
        ("LANGCHAIN_TENANT_ID", "test-tenant-id"),
        ("LANGCHAIN_ENDPOINT", "http://test-endpoint.com"),
        ("LANGCHAIN_API_KEY", "foo"),
    )
    for variable, value in placeholder_env:
        monkeypatch.setenv(variable, value)
    return LangChainTracer()
|
|
|
|
|
|
@pytest.fixture
def sample_tracer_session_v2() -> TracerSession:
    """Build a sample TracerSession from the module's fixed session and tenant ids."""
    session_fields = {
        "id": _SESSION_ID,
        "name": "Sample session",
        "tenant_id": _TENANT_ID,
    }
    return TracerSession(**session_fields)
|
|
|
|
|
|
# ``pytest.fixture`` must be the outermost decorator: it has to register the
# time-frozen function. With ``freeze_time`` on the outside, the freeze wraps
# the fixture object instead of the function body.
@pytest.fixture
@freeze_time("2023-01-01")
def sample_runs() -> Tuple[Run, Run, Run]:
    """Build three sample runs: an LLM run, a chain run and a tool run.

    The chain run lists the LLM run in ``child_runs`` and the LLM run names the
    chain run's id as its ``parent_run_id``. The tool run has no children and
    no outputs. Timestamps are taken while the clock is frozen at 2023-01-01.

    Returns:
        The tuple ``(llm_run, chain_run, tool_run)``.
    """
    # Shared so the parent/child link between the chain and LLM runs is explicit.
    chain_run_id = "57a08cc4-73d2-4236-8371-549099d07fad"

    llm_run = Run(
        id="57a08cc4-73d2-4236-8370-549099d07fad",
        name="llm_run",
        execution_order=1,
        child_execution_order=1,
        parent_run_id=chain_run_id,
        start_time=datetime.utcnow(),
        end_time=datetime.utcnow(),
        # NOTE(review): only this run passes session_id; confirm the Run
        # schema still accepts or ignores it.
        session_id=1,
        inputs={"prompts": []},
        outputs=LLMResult(generations=[[]]).dict(),
        serialized={},
        extra={},
        run_type=RunTypeEnum.llm,
    )
    chain_run = Run(
        id=chain_run_id,
        name="chain_run",
        execution_order=1,
        start_time=datetime.utcnow(),
        end_time=datetime.utcnow(),
        child_execution_order=1,
        serialized={},
        inputs={},
        outputs={},
        child_runs=[llm_run],
        extra={},
        run_type=RunTypeEnum.chain,
    )

    tool_run = Run(
        id="57a08cc4-73d2-4236-8372-549099d07fad",
        name="tool_run",
        execution_order=1,
        child_execution_order=1,
        inputs={"input": "test"},
        start_time=datetime.utcnow(),
        end_time=datetime.utcnow(),
        outputs=None,
        serialized={},
        child_runs=[],
        extra={},
        run_type=RunTypeEnum.tool,
    )
    return llm_run, chain_run, tool_run
|
|
|
|
|
|
def test_persist_run(
    lang_chain_tracer_v2: LangChainTracer,
    sample_tracer_session_v2: TracerSession,
    sample_runs: Tuple[Run, Run, Run],
) -> None:
    """Ending the trace of every sample run issues three POSTs and no GETs."""
    tracer = lang_chain_tracer_v2
    with patch("langchain.callbacks.tracers.langchain.requests.post") as post:
        with patch("langchain.callbacks.tracers.langchain.requests.get") as get:
            post.return_value.raise_for_status.return_value = None
            tracer.session = sample_tracer_session_v2

            # Register every run before ending any trace; the two passes are
            # deliberately kept separate.
            for sample in sample_runs:
                tracer.run_map[str(sample.id)] = sample
            for sample in sample_runs:
                tracer._end_trace(sample)

            assert post.call_count == 3
            assert get.call_count == 0
|
|
|
|
|
|
def test_persist_run_with_example_id(
    lang_chain_tracer_v2: LangChainTracer,
    sample_tracer_session_v2: TracerSession,
    sample_runs: Tuple[Run, Run, Run],
) -> None:
    """Check that only the root run's payload carries the reference example id.

    The runs are re-nested as chain -> tool -> llm, the chain run is persisted,
    and the three posted payloads are inspected in call order.
    """
    tracer = lang_chain_tracer_v2
    example_id = uuid4()
    llm_run, chain_run, tool_run = sample_runs
    chain_run.child_runs = [tool_run]
    tool_run.child_runs = [llm_run]

    with patch("langchain.callbacks.tracers.langchain.requests.post") as post:
        with patch("langchain.callbacks.tracers.langchain.requests.get") as get:
            post.return_value.raise_for_status.return_value = None
            tracer.session = sample_tracer_session_v2
            tracer.example_id = example_id
            tracer._persist_run(chain_run)

            assert post.call_count == 3
            assert get.call_count == 0

            payloads = [
                json.loads(call.kwargs["data"]) for call in post.call_args_list
            ]
            root_payload, *child_payloads = payloads
            assert root_payload["id"] == str(chain_run.id)
            assert root_payload["reference_example_id"] == str(example_id)

            # Children are posted in nesting order and must not carry the id.
            for payload, child in zip(child_payloads, (tool_run, llm_run)):
                assert payload["id"] == str(child.id)
                assert not payload.get("reference_example_id")
|