Add tags support for langchaintracer (#7207)

This commit is contained in:
William FH 2023-07-05 16:19:04 -07:00 committed by GitHub
parent 75aa408f10
commit 607708a411
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 16 additions and 1 deletions

View File

@ -144,6 +144,7 @@ def tracing_v2_enabled(
project_name: Optional[str] = None, project_name: Optional[str] = None,
*, *,
example_id: Optional[Union[str, UUID]] = None, example_id: Optional[Union[str, UUID]] = None,
tags: Optional[List[str]] = None,
) -> Generator[None, None, None]: ) -> Generator[None, None, None]:
"""Instruct LangChain to log all runs in context to LangSmith. """Instruct LangChain to log all runs in context to LangSmith.
@ -152,6 +153,8 @@ def tracing_v2_enabled(
Defaults to "default". Defaults to "default".
example_id (str or UUID, optional): The ID of the example. example_id (str or UUID, optional): The ID of the example.
Defaults to None. Defaults to None.
tags (List[str], optional): The tags to add to the run.
Defaults to None.
Returns: Returns:
None None
@ -170,6 +173,7 @@ def tracing_v2_enabled(
cb = LangChainTracer( cb = LangChainTracer(
example_id=example_id, example_id=example_id,
project_name=project_name, project_name=project_name,
tags=tags,
) )
tracing_v2_callback_var.set(cb) tracing_v2_callback_var.set(cb)
yield yield

View File

@ -44,6 +44,7 @@ class LangChainTracer(BaseTracer):
example_id: Optional[Union[UUID, str]] = None, example_id: Optional[Union[UUID, str]] = None,
project_name: Optional[str] = None, project_name: Optional[str] = None,
client: Optional[LangChainPlusClient] = None, client: Optional[LangChainPlusClient] = None,
tags: Optional[List[str]] = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
"""Initialize the LangChain tracer.""" """Initialize the LangChain tracer."""
@ -59,6 +60,7 @@ class LangChainTracer(BaseTracer):
self.executor = ThreadPoolExecutor(max_workers=1) self.executor = ThreadPoolExecutor(max_workers=1)
self.client = client or LangChainPlusClient() self.client = client or LangChainPlusClient()
self._futures: Set[Future] = set() self._futures: Set[Future] = set()
self.tags = tags or []
global _TRACERS global _TRACERS
_TRACERS.append(self) _TRACERS.append(self)
@ -98,11 +100,18 @@ class LangChainTracer(BaseTracer):
def _persist_run(self, run: Run) -> None: def _persist_run(self, run: Run) -> None:
"""The Langchain Tracer uses Post/Patch rather than persist.""" """The Langchain Tracer uses Post/Patch rather than persist."""
def _get_tags(self, run: Run) -> List[str]:
"""Get combined tags for a run."""
tags = set(run.tags or [])
tags.update(self.tags or [])
return list(tags)
def _persist_run_single(self, run: Run) -> None: def _persist_run_single(self, run: Run) -> None:
"""Persist a run.""" """Persist a run."""
if run.parent_run_id is None: if run.parent_run_id is None:
run.reference_example_id = self.example_id run.reference_example_id = self.example_id
run_dict = run.dict(exclude={"child_runs"}) run_dict = run.dict(exclude={"child_runs"})
run_dict["tags"] = self._get_tags(run)
extra = run_dict.get("extra", {}) extra = run_dict.get("extra", {})
extra["runtime"] = get_runtime_environment() extra["runtime"] = get_runtime_environment()
run_dict["extra"] = extra run_dict["extra"] = extra
@ -116,7 +125,9 @@ class LangChainTracer(BaseTracer):
def _update_run_single(self, run: Run) -> None: def _update_run_single(self, run: Run) -> None:
"""Update a run.""" """Update a run."""
try: try:
self.client.update_run(run.id, **run.dict()) run_dict = run.dict()
run_dict["tags"] = self._get_tags(run)
self.client.update_run(run.id, **run_dict)
except Exception as e: except Exception as e:
# Errors are swallowed by the thread executor so we need to log them here # Errors are swallowed by the thread executor so we need to log them here
log_error_once("patch", e) log_error_once("patch", e)