mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
Add tags support for langchaintracer (#7207)
This commit is contained in:
parent
75aa408f10
commit
607708a411
@ -144,6 +144,7 @@ def tracing_v2_enabled(
|
|||||||
project_name: Optional[str] = None,
|
project_name: Optional[str] = None,
|
||||||
*,
|
*,
|
||||||
example_id: Optional[Union[str, UUID]] = None,
|
example_id: Optional[Union[str, UUID]] = None,
|
||||||
|
tags: Optional[List[str]] = None,
|
||||||
) -> Generator[None, None, None]:
|
) -> Generator[None, None, None]:
|
||||||
"""Instruct LangChain to log all runs in context to LangSmith.
|
"""Instruct LangChain to log all runs in context to LangSmith.
|
||||||
|
|
||||||
@ -152,6 +153,8 @@ def tracing_v2_enabled(
|
|||||||
Defaults to "default".
|
Defaults to "default".
|
||||||
example_id (str or UUID, optional): The ID of the example.
|
example_id (str or UUID, optional): The ID of the example.
|
||||||
Defaults to None.
|
Defaults to None.
|
||||||
|
tags (List[str], optional): The tags to add to the run.
|
||||||
|
Defaults to None.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
None
|
None
|
||||||
@ -170,6 +173,7 @@ def tracing_v2_enabled(
|
|||||||
cb = LangChainTracer(
|
cb = LangChainTracer(
|
||||||
example_id=example_id,
|
example_id=example_id,
|
||||||
project_name=project_name,
|
project_name=project_name,
|
||||||
|
tags=tags,
|
||||||
)
|
)
|
||||||
tracing_v2_callback_var.set(cb)
|
tracing_v2_callback_var.set(cb)
|
||||||
yield
|
yield
|
||||||
|
@ -44,6 +44,7 @@ class LangChainTracer(BaseTracer):
|
|||||||
example_id: Optional[Union[UUID, str]] = None,
|
example_id: Optional[Union[UUID, str]] = None,
|
||||||
project_name: Optional[str] = None,
|
project_name: Optional[str] = None,
|
||||||
client: Optional[LangChainPlusClient] = None,
|
client: Optional[LangChainPlusClient] = None,
|
||||||
|
tags: Optional[List[str]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize the LangChain tracer."""
|
"""Initialize the LangChain tracer."""
|
||||||
@ -59,6 +60,7 @@ class LangChainTracer(BaseTracer):
|
|||||||
self.executor = ThreadPoolExecutor(max_workers=1)
|
self.executor = ThreadPoolExecutor(max_workers=1)
|
||||||
self.client = client or LangChainPlusClient()
|
self.client = client or LangChainPlusClient()
|
||||||
self._futures: Set[Future] = set()
|
self._futures: Set[Future] = set()
|
||||||
|
self.tags = tags or []
|
||||||
global _TRACERS
|
global _TRACERS
|
||||||
_TRACERS.append(self)
|
_TRACERS.append(self)
|
||||||
|
|
||||||
@ -98,11 +100,18 @@ class LangChainTracer(BaseTracer):
|
|||||||
def _persist_run(self, run: Run) -> None:
|
def _persist_run(self, run: Run) -> None:
|
||||||
"""The Langchain Tracer uses Post/Patch rather than persist."""
|
"""The Langchain Tracer uses Post/Patch rather than persist."""
|
||||||
|
|
||||||
|
def _get_tags(self, run: Run) -> List[str]:
|
||||||
|
"""Get combined tags for a run."""
|
||||||
|
tags = set(run.tags or [])
|
||||||
|
tags.update(self.tags or [])
|
||||||
|
return list(tags)
|
||||||
|
|
||||||
def _persist_run_single(self, run: Run) -> None:
|
def _persist_run_single(self, run: Run) -> None:
|
||||||
"""Persist a run."""
|
"""Persist a run."""
|
||||||
if run.parent_run_id is None:
|
if run.parent_run_id is None:
|
||||||
run.reference_example_id = self.example_id
|
run.reference_example_id = self.example_id
|
||||||
run_dict = run.dict(exclude={"child_runs"})
|
run_dict = run.dict(exclude={"child_runs"})
|
||||||
|
run_dict["tags"] = self._get_tags(run)
|
||||||
extra = run_dict.get("extra", {})
|
extra = run_dict.get("extra", {})
|
||||||
extra["runtime"] = get_runtime_environment()
|
extra["runtime"] = get_runtime_environment()
|
||||||
run_dict["extra"] = extra
|
run_dict["extra"] = extra
|
||||||
@ -116,7 +125,9 @@ class LangChainTracer(BaseTracer):
|
|||||||
def _update_run_single(self, run: Run) -> None:
|
def _update_run_single(self, run: Run) -> None:
|
||||||
"""Update a run."""
|
"""Update a run."""
|
||||||
try:
|
try:
|
||||||
self.client.update_run(run.id, **run.dict())
|
run_dict = run.dict()
|
||||||
|
run_dict["tags"] = self._get_tags(run)
|
||||||
|
self.client.update_run(run.id, **run_dict)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Errors are swallowed by the thread executor so we need to log them here
|
# Errors are swallowed by the thread executor so we need to log them here
|
||||||
log_error_once("patch", e)
|
log_error_once("patch", e)
|
||||||
|
Loading…
Reference in New Issue
Block a user