Add tags support for langchaintracer (#7207)

pull/7235/head^2
William FH 1 year ago committed by GitHub
parent 75aa408f10
commit 607708a411
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -144,6 +144,7 @@ def tracing_v2_enabled(
project_name: Optional[str] = None,
*,
example_id: Optional[Union[str, UUID]] = None,
tags: Optional[List[str]] = None,
) -> Generator[None, None, None]:
"""Instruct LangChain to log all runs in context to LangSmith.
@ -152,6 +153,8 @@ def tracing_v2_enabled(
Defaults to "default".
example_id (str or UUID, optional): The ID of the example.
Defaults to None.
tags (List[str], optional): The tags to add to the run.
Defaults to None.
Returns:
None
@ -170,6 +173,7 @@ def tracing_v2_enabled(
cb = LangChainTracer(
example_id=example_id,
project_name=project_name,
tags=tags,
)
tracing_v2_callback_var.set(cb)
yield

@ -44,6 +44,7 @@ class LangChainTracer(BaseTracer):
example_id: Optional[Union[UUID, str]] = None,
project_name: Optional[str] = None,
client: Optional[LangChainPlusClient] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
"""Initialize the LangChain tracer."""
@ -59,6 +60,7 @@ class LangChainTracer(BaseTracer):
self.executor = ThreadPoolExecutor(max_workers=1)
self.client = client or LangChainPlusClient()
self._futures: Set[Future] = set()
self.tags = tags or []
global _TRACERS
_TRACERS.append(self)
@ -98,11 +100,18 @@ class LangChainTracer(BaseTracer):
def _persist_run(self, run: Run) -> None:
"""The Langchain Tracer uses Post/Patch rather than persist."""
def _get_tags(self, run: Run) -> List[str]:
"""Get combined tags for a run."""
tags = set(run.tags or [])
tags.update(self.tags or [])
return list(tags)
def _persist_run_single(self, run: Run) -> None:
"""Persist a run."""
if run.parent_run_id is None:
run.reference_example_id = self.example_id
run_dict = run.dict(exclude={"child_runs"})
run_dict["tags"] = self._get_tags(run)
extra = run_dict.get("extra", {})
extra["runtime"] = get_runtime_environment()
run_dict["extra"] = extra
@ -116,7 +125,9 @@ class LangChainTracer(BaseTracer):
def _update_run_single(self, run: Run) -> None:
"""Update a run."""
try:
self.client.update_run(run.id, **run.dict())
run_dict = run.dict()
run_dict["tags"] = self._get_tags(run)
self.client.update_run(run.id, **run_dict)
except Exception as e:
# Errors are swallowed by the thread executor so we need to log them here
log_error_once("patch", e)

Loading…
Cancel
Save