mirror of
https://github.com/hwchase17/langchain
synced 2024-11-10 01:10:59 +00:00
1c9ceff503
Hey, I'm Sasha, the SDK engineer from [Comet](https://comet.com). This PR updates the CometTracer class by adding metadata to CometTracer. From now on, both chains and spans will send it.
136 lines
4.5 KiB
Python
136 lines
4.5 KiB
Python
from types import ModuleType, SimpleNamespace
|
|
from typing import TYPE_CHECKING, Any, Callable, Dict
|
|
|
|
from langchain_core.tracers import BaseTracer
|
|
from langchain_core.utils import guard_import
|
|
|
|
if TYPE_CHECKING:
|
|
from uuid import UUID
|
|
|
|
from comet_llm import Span
|
|
from comet_llm.chains.chain import Chain
|
|
|
|
from langchain_community.callbacks.tracers.schemas import Run
|
|
|
|
|
|
def _get_run_type(run: "Run") -> str:
    """Normalize ``run.run_type`` to a string for use as a Comet span category.

    Plain strings are returned unchanged. Objects exposing a ``value``
    attribute (e.g. enum members) yield that value. Anything else is
    converted with ``str()``.
    """
    raw_type = run.run_type
    if isinstance(raw_type, str):
        return raw_type
    return raw_type.value if hasattr(raw_type, "value") else str(raw_type)
|
|
|
|
|
|
def import_comet_llm_api() -> SimpleNamespace:
    """Import the pieces of the comet_llm API used by the tracer.

    Raises an error (via ``guard_import``) if ``comet_llm`` is not installed.

    Returns:
        A namespace holding ``chain``, ``span`` and ``chain_api`` (taken from
        ``comet_llm.chains``) plus ``experiment_info`` and ``flush`` (taken
        from ``comet_llm``).
    """
    core = guard_import("comet_llm")
    chains = guard_import("comet_llm.chains")

    chain_module = chains.chain
    span_module = chains.span
    chain_api_module = chains.api
    return SimpleNamespace(
        chain=chain_module,
        span=span_module,
        chain_api=chain_api_module,
        experiment_info=core.experiment_info,
        flush=core.flush,
    )
|
|
|
|
|
|
class CometTracer(BaseTracer):
    """Comet Tracer.

    Reports LangChain runs to Comet through the ``comet_llm`` chains API.
    The root run of a trace becomes a Comet ``Chain``, and every nested run
    becomes a ``Span`` started against that root's ``Chain``. The ``Chain``
    is logged to Comet in ``_persist_run``.
    """

    def __init__(self, **kwargs: Any) -> None:
        """Initialize the Comet Tracer.

        Args:
            **kwargs: Forwarded unchanged to ``BaseTracer.__init__``.

        Raises an error (via ``guard_import``) if ``comet_llm`` is not installed.
        """
        super().__init__(**kwargs)
        # Run id -> open Span, for each nested run that has started but not ended.
        self._span_map: Dict["UUID", "Span"] = {}
        # Run id -> Chain the run belongs to. A root run maps to its own Chain,
        # and nested runs map to their root's Chain. Entries are removed when
        # the run ends so a long-lived tracer does not accumulate them forever.
        self._chains_map: Dict["UUID", "Chain"] = {}
        self._initialize_comet_modules()

    def _initialize_comet_modules(self) -> None:
        """Import ``comet_llm`` and cache the parts of its API used here."""
        comet_llm_api = import_comet_llm_api()
        self._chain: ModuleType = comet_llm_api.chain
        self._span: ModuleType = comet_llm_api.span
        self._chain_api: ModuleType = comet_llm_api.chain_api
        self._experiment_info: ModuleType = comet_llm_api.experiment_info
        self._flush: Callable[[], None] = comet_llm_api.flush

    def _persist_run(self, run: "Run") -> None:
        """Attach a root run's outputs to its Chain and log the Chain to Comet.

        LangChain invokes this hook for root runs. It is the point where a
        finished trace is sent to Comet.
        """
        run_dict: Dict[str, Any] = run.dict()
        # Pop rather than read: the root run is finished, so its entry is no
        # longer needed and would otherwise be retained for the tracer's lifetime.
        chain_ = self._chains_map.pop(run.id)
        chain_.set_outputs(outputs=run_dict["outputs"])
        self._chain_api.log_chain(chain_)

    def _process_start_trace(self, run: "Run") -> None:
        """Open a Comet Chain for a root run, or a Span for a nested run."""
        run_dict: Dict[str, Any] = run.dict()
        if not run.parent_run_id:
            # This is the first run, which maps to a chain.
            # NOTE(review): ``extra`` is treated as optional; guard so that a
            # missing dict does not raise AttributeError on ``.get``.
            extra = run_dict["extra"] or {}
            chain_: "Chain" = self._chain.Chain(
                inputs=run_dict["inputs"],
                metadata=extra.get("metadata", None),
                experiment_info=self._experiment_info.get(),
            )
            self._chains_map[run.id] = chain_
            return

        # NOTE(review): spans receive the whole ``extra`` dict, whereas the root
        # Chain receives only ``extra["metadata"]``. Confirm this is intended.
        span: "Span" = self._span.Span(
            inputs=run_dict["inputs"],
            category=_get_run_type(run),
            metadata=run_dict["extra"],
            name=run.name,
        )
        parent_chain = self._chains_map[run.parent_run_id]
        span.__api__start__(parent_chain)
        # Nested runs share their root's Chain so that their own children can
        # find it through ``parent_run_id``.
        self._chains_map[run.id] = parent_chain
        self._span_map[run.id] = span

    def _process_end_trace(self, run: "Run") -> None:
        """Close the Span of a nested run and drop its bookkeeping entries.

        Root runs are skipped because LangChain calls ``_persist_run`` for them.
        """
        if not run.parent_run_id:
            return
        run_dict: Dict[str, Any] = run.dict()
        # Pop both maps so finished runs are not retained indefinitely.
        span = self._span_map.pop(run.id)
        self._chains_map.pop(run.id, None)
        span.set_outputs(outputs=run_dict["outputs"])
        span.__api__end__()

    def flush(self) -> None:
        """Call ``comet_llm.flush`` (presumably sends any buffered data to Comet)."""
        self._flush()

    def _on_llm_start(self, run: "Run") -> None:
        """Process the LLM Run upon start."""
        self._process_start_trace(run)

    def _on_llm_end(self, run: "Run") -> None:
        """Process the LLM Run."""
        self._process_end_trace(run)

    def _on_llm_error(self, run: "Run") -> None:
        """Process the LLM Run upon error."""
        self._process_end_trace(run)

    def _on_chain_start(self, run: "Run") -> None:
        """Process the Chain Run upon start."""
        self._process_start_trace(run)

    def _on_chain_end(self, run: "Run") -> None:
        """Process the Chain Run."""
        self._process_end_trace(run)

    def _on_chain_error(self, run: "Run") -> None:
        """Process the Chain Run upon error."""
        self._process_end_trace(run)

    def _on_tool_start(self, run: "Run") -> None:
        """Process the Tool Run upon start."""
        self._process_start_trace(run)

    def _on_tool_end(self, run: "Run") -> None:
        """Process the Tool Run."""
        self._process_end_trace(run)

    def _on_tool_error(self, run: "Run") -> None:
        """Process the Tool Run upon error."""
        self._process_end_trace(run)
|