diff --git a/langchain/callbacks/manager.py b/langchain/callbacks/manager.py index 77a83c4123..4dde82c26c 100644 --- a/langchain/callbacks/manager.py +++ b/langchain/callbacks/manager.py @@ -144,6 +144,7 @@ def tracing_v2_enabled( project_name: Optional[str] = None, *, example_id: Optional[Union[str, UUID]] = None, + tags: Optional[List[str]] = None, ) -> Generator[None, None, None]: """Instruct LangChain to log all runs in context to LangSmith. @@ -152,6 +153,8 @@ def tracing_v2_enabled( Defaults to "default". example_id (str or UUID, optional): The ID of the example. Defaults to None. + tags (List[str], optional): The tags to add to the run. + Defaults to None. Returns: None @@ -170,6 +173,7 @@ def tracing_v2_enabled( cb = LangChainTracer( example_id=example_id, project_name=project_name, + tags=tags, ) tracing_v2_callback_var.set(cb) yield diff --git a/langchain/callbacks/tracers/langchain.py b/langchain/callbacks/tracers/langchain.py index 5759019cc7..165eedea41 100644 --- a/langchain/callbacks/tracers/langchain.py +++ b/langchain/callbacks/tracers/langchain.py @@ -44,6 +44,7 @@ class LangChainTracer(BaseTracer): example_id: Optional[Union[UUID, str]] = None, project_name: Optional[str] = None, client: Optional[LangChainPlusClient] = None, + tags: Optional[List[str]] = None, **kwargs: Any, ) -> None: """Initialize the LangChain tracer.""" @@ -59,6 +60,7 @@ class LangChainTracer(BaseTracer): self.executor = ThreadPoolExecutor(max_workers=1) self.client = client or LangChainPlusClient() self._futures: Set[Future] = set() + self.tags = tags or [] global _TRACERS _TRACERS.append(self) @@ -98,11 +100,18 @@ class LangChainTracer(BaseTracer): def _persist_run(self, run: Run) -> None: """The Langchain Tracer uses Post/Patch rather than persist.""" + def _get_tags(self, run: Run) -> List[str]: + """Get combined tags for a run.""" + tags = set(run.tags or []) + tags.update(self.tags or []) + return list(tags) + def _persist_run_single(self, run: Run) -> None: """Persist a run.""" if run.parent_run_id is None: run.reference_example_id = self.example_id run_dict = run.dict(exclude={"child_runs"}) + run_dict["tags"] = self._get_tags(run) extra = run_dict.get("extra", {}) extra["runtime"] = get_runtime_environment() run_dict["extra"] = extra @@ -116,7 +125,9 @@ class LangChainTracer(BaseTracer): def _update_run_single(self, run: Run) -> None: """Update a run.""" try: - self.client.update_run(run.id, **run.dict()) + run_dict = run.dict() + run_dict["tags"] = self._get_tags(run) + self.client.update_run(run.id, **run_dict) except Exception as e: # Errors are swallowed by the thread executor so we need to log them here log_error_once("patch", e)