|
|
|
@ -44,6 +44,7 @@ class LangChainTracer(BaseTracer):
|
|
|
|
|
example_id: Optional[Union[UUID, str]] = None,
|
|
|
|
|
project_name: Optional[str] = None,
|
|
|
|
|
client: Optional[LangChainPlusClient] = None,
|
|
|
|
|
tags: Optional[List[str]] = None,
|
|
|
|
|
**kwargs: Any,
|
|
|
|
|
) -> None:
|
|
|
|
|
"""Initialize the LangChain tracer."""
|
|
|
|
@ -59,6 +60,7 @@ class LangChainTracer(BaseTracer):
|
|
|
|
|
self.executor = ThreadPoolExecutor(max_workers=1)
|
|
|
|
|
self.client = client or LangChainPlusClient()
|
|
|
|
|
self._futures: Set[Future] = set()
|
|
|
|
|
self.tags = tags or []
|
|
|
|
|
global _TRACERS
|
|
|
|
|
_TRACERS.append(self)
|
|
|
|
|
|
|
|
|
@ -98,11 +100,18 @@ class LangChainTracer(BaseTracer):
|
|
|
|
|
def _persist_run(self, run: Run) -> None:
|
|
|
|
|
"""The Langchain Tracer uses Post/Patch rather than persist."""
|
|
|
|
|
|
|
|
|
|
def _get_tags(self, run: Run) -> List[str]:
|
|
|
|
|
"""Get combined tags for a run."""
|
|
|
|
|
tags = set(run.tags or [])
|
|
|
|
|
tags.update(self.tags or [])
|
|
|
|
|
return list(tags)
|
|
|
|
|
|
|
|
|
|
def _persist_run_single(self, run: Run) -> None:
|
|
|
|
|
"""Persist a run."""
|
|
|
|
|
if run.parent_run_id is None:
|
|
|
|
|
run.reference_example_id = self.example_id
|
|
|
|
|
run_dict = run.dict(exclude={"child_runs"})
|
|
|
|
|
run_dict["tags"] = self._get_tags(run)
|
|
|
|
|
extra = run_dict.get("extra", {})
|
|
|
|
|
extra["runtime"] = get_runtime_environment()
|
|
|
|
|
run_dict["extra"] = extra
|
|
|
|
@ -116,7 +125,9 @@ class LangChainTracer(BaseTracer):
|
|
|
|
|
def _update_run_single(self, run: Run) -> None:
|
|
|
|
|
"""Update a run."""
|
|
|
|
|
try:
|
|
|
|
|
self.client.update_run(run.id, **run.dict())
|
|
|
|
|
run_dict = run.dict()
|
|
|
|
|
run_dict["tags"] = self._get_tags(run)
|
|
|
|
|
self.client.update_run(run.id, **run_dict)
|
|
|
|
|
except Exception as e:
|
|
|
|
|
# Errors are swallowed by the thread executor so we need to log them here
|
|
|
|
|
log_error_once("patch", e)
|
|
|
|
|