mirror of
https://github.com/hwchase17/langchain
synced 2024-11-04 06:00:26 +00:00
Support serialization when inputs/outputs contain generators (#17338)
Pydantic's `dict()` function raises an error here if you pass in a generator. We have a more robust serialization function in lagnsmith that we will use instead.
This commit is contained in:
parent
3a2eb6e12b
commit
7c03cc5ed4
@ -1,4 +1,5 @@
|
||||
"""A Tracer implementation that records to LangChain endpoint."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
@ -62,6 +63,14 @@ def _get_executor() -> ThreadPoolExecutor:
|
||||
return _EXECUTOR
|
||||
|
||||
|
||||
def _run_to_dict(run: Run) -> dict:
|
||||
return {
|
||||
**run.dict(exclude={"child_runs", "inputs", "outputs"}),
|
||||
"inputs": run.inputs.copy() if run.inputs is not None else None,
|
||||
"outputs": run.outputs.copy() if run.outputs is not None else None,
|
||||
}
|
||||
|
||||
|
||||
class LangChainTracer(BaseTracer):
|
||||
"""Implementation of the SharedTracer that POSTS to the LangChain endpoint."""
|
||||
|
||||
@ -150,7 +159,7 @@ class LangChainTracer(BaseTracer):
|
||||
|
||||
def _persist_run_single(self, run: Run) -> None:
|
||||
"""Persist a run."""
|
||||
run_dict = run.dict(exclude={"child_runs"})
|
||||
run_dict = _run_to_dict(run)
|
||||
run_dict["tags"] = self._get_tags(run)
|
||||
extra = run_dict.get("extra", {})
|
||||
extra["runtime"] = get_runtime_environment()
|
||||
@ -165,7 +174,7 @@ class LangChainTracer(BaseTracer):
|
||||
def _update_run_single(self, run: Run) -> None:
|
||||
"""Update a run."""
|
||||
try:
|
||||
run_dict = run.dict()
|
||||
run_dict = _run_to_dict(run)
|
||||
run_dict["tags"] = self._get_tags(run)
|
||||
self.client.update_run(run.id, **run_dict)
|
||||
except Exception as e:
|
||||
|
@ -22,6 +22,7 @@ def test_example_id_assignment_threadsafe() -> None:
|
||||
return unittest.mock.MagicMock()
|
||||
|
||||
client = unittest.mock.MagicMock(spec=Client)
|
||||
client.tracing_queue = None
|
||||
client.create_run = mock_create_run
|
||||
tracer = LangChainTracer(client=client)
|
||||
old_persist_run_single = tracer._persist_run_single
|
||||
@ -35,6 +36,7 @@ def test_example_id_assignment_threadsafe() -> None:
|
||||
):
|
||||
run_id_1 = UUID("9d878ab3-e5ca-4218-aef6-44cbdc90160a")
|
||||
run_id_2 = UUID("f1f9fa53-8b2f-4742-bdbc-38215f7bd1e1")
|
||||
run_id_3 = UUID("f1f9fa53-8b2f-4742-bdbc-38215f7cd1e1")
|
||||
example_id_1 = UUID("57e42c57-8c79-4d9f-8765-bf6cd3a98055")
|
||||
tracer.example_id = example_id_1
|
||||
tracer.on_llm_start({"name": "example_1"}, ["foo"], run_id=run_id_1)
|
||||
@ -44,9 +46,14 @@ def test_example_id_assignment_threadsafe() -> None:
|
||||
tracer.on_llm_start({"name": "example_2"}, ["foo"], run_id=run_id_2)
|
||||
tracer.on_llm_end(LLMResult(generations=[], llm_output={}), run_id=run_id_2)
|
||||
tracer.example_id = None
|
||||
tracer.on_chain_start(
|
||||
{"name": "no_examples"}, {"inputs": (i for i in range(10))}, run_id=run_id_3
|
||||
)
|
||||
tracer.on_chain_error(ValueError("Foo bar"), run_id=run_id_3)
|
||||
expected_example_ids = {
|
||||
run_id_1: example_id_1,
|
||||
run_id_2: example_id_2,
|
||||
run_id_3: None,
|
||||
}
|
||||
tracer.wait_for_futures()
|
||||
assert example_ids == expected_example_ids
|
||||
|
Loading…
Reference in New Issue
Block a user