|
|
@ -14,21 +14,32 @@ from langchain.callbacks.tracers.schemas import (
|
|
|
|
Run,
|
|
|
|
Run,
|
|
|
|
ToolRun,
|
|
|
|
ToolRun,
|
|
|
|
TracerSession,
|
|
|
|
TracerSession,
|
|
|
|
TracerSessionCreate,
|
|
|
|
TracerSessionBase,
|
|
|
|
|
|
|
|
TracerSessionV2,
|
|
|
|
|
|
|
|
TracerSessionV2Create,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_headers() -> Dict[str, Any]:
|
|
|
|
|
|
|
|
"""Get the headers for the LangChain API."""
|
|
|
|
|
|
|
|
headers: Dict[str, Any] = {"Content-Type": "application/json"}
|
|
|
|
|
|
|
|
if os.getenv("LANGCHAIN_API_KEY"):
|
|
|
|
|
|
|
|
headers["x-api-key"] = os.getenv("LANGCHAIN_API_KEY")
|
|
|
|
|
|
|
|
return headers
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_endpoint() -> str:
|
|
|
|
|
|
|
|
return os.getenv("LANGCHAIN_ENDPOINT", "http://localhost:8000")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LangChainTracer(BaseTracer):
|
|
|
|
class LangChainTracer(BaseTracer):
|
|
|
|
"""An implementation of the SharedTracer that POSTS to the langchain endpoint."""
|
|
|
|
"""An implementation of the SharedTracer that POSTS to the langchain endpoint."""
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, session_name: str = "default", **kwargs: Any) -> None:
|
|
|
|
def __init__(self, **kwargs: Any) -> None:
|
|
|
|
"""Initialize the LangChain tracer."""
|
|
|
|
"""Initialize the LangChain tracer."""
|
|
|
|
super().__init__(**kwargs)
|
|
|
|
super().__init__(**kwargs)
|
|
|
|
self._endpoint: str = os.getenv("LANGCHAIN_ENDPOINT", "http://localhost:8000")
|
|
|
|
self._endpoint = _get_endpoint()
|
|
|
|
self._headers: Dict[str, Any] = {"Content-Type": "application/json"}
|
|
|
|
self._headers = _get_headers()
|
|
|
|
if os.getenv("LANGCHAIN_API_KEY"):
|
|
|
|
|
|
|
|
self._headers["x-api-key"] = os.getenv("LANGCHAIN_API_KEY")
|
|
|
|
|
|
|
|
self.session = self.load_session(session_name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:
|
|
|
|
def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:
|
|
|
|
"""Persist a run."""
|
|
|
|
"""Persist a run."""
|
|
|
@ -48,7 +59,9 @@ class LangChainTracer(BaseTracer):
|
|
|
|
except Exception as e:
|
|
|
|
except Exception as e:
|
|
|
|
logging.warning(f"Failed to persist run: {e}")
|
|
|
|
logging.warning(f"Failed to persist run: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
def _persist_session(self, session_create: TracerSessionCreate) -> TracerSession:
|
|
|
|
def _persist_session(
|
|
|
|
|
|
|
|
self, session_create: TracerSessionBase
|
|
|
|
|
|
|
|
) -> Union[TracerSession, TracerSessionV2]:
|
|
|
|
"""Persist a session."""
|
|
|
|
"""Persist a session."""
|
|
|
|
try:
|
|
|
|
try:
|
|
|
|
r = requests.post(
|
|
|
|
r = requests.post(
|
|
|
@ -81,22 +94,89 @@ class LangChainTracer(BaseTracer):
|
|
|
|
self.session = tracer_session
|
|
|
|
self.session = tracer_session
|
|
|
|
return tracer_session
|
|
|
|
return tracer_session
|
|
|
|
|
|
|
|
|
|
|
|
def load_session(self, session_name: str) -> TracerSession:
|
|
|
|
def load_session(self, session_name: str) -> Union[TracerSession, TracerSessionV2]:
|
|
|
|
"""Load a session with the given name from the tracer."""
|
|
|
|
"""Load a session with the given name from the tracer."""
|
|
|
|
return self._load_session(session_name)
|
|
|
|
return self._load_session(session_name)
|
|
|
|
|
|
|
|
|
|
|
|
def load_default_session(self) -> TracerSession:
|
|
|
|
def load_default_session(self) -> Union[TracerSession, TracerSessionV2]:
|
|
|
|
"""Load the default tracing session and set it as the Tracer's session."""
|
|
|
|
"""Load the default tracing session and set it as the Tracer's session."""
|
|
|
|
return self._load_session("default")
|
|
|
|
return self._load_session("default")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_tenant_id() -> Optional[str]:
|
|
|
|
|
|
|
|
"""Get the tenant ID for the LangChain API."""
|
|
|
|
|
|
|
|
tenant_id: Optional[str] = os.getenv("LANGCHAIN_TENANT_ID")
|
|
|
|
|
|
|
|
if tenant_id:
|
|
|
|
|
|
|
|
return tenant_id
|
|
|
|
|
|
|
|
endpoint = _get_endpoint()
|
|
|
|
|
|
|
|
headers = _get_headers()
|
|
|
|
|
|
|
|
response = requests.get(endpoint + "/tenants", headers=headers)
|
|
|
|
|
|
|
|
response.raise_for_status()
|
|
|
|
|
|
|
|
tenants: List[Dict[str, Any]] = response.json()
|
|
|
|
|
|
|
|
if not tenants:
|
|
|
|
|
|
|
|
raise ValueError(f"No tenants found for URL {endpoint}")
|
|
|
|
|
|
|
|
return tenants[0]["id"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LangChainTracerV2(LangChainTracer):
|
|
|
|
class LangChainTracerV2(LangChainTracer):
|
|
|
|
"""An implementation of the SharedTracer that POSTS to the langchain endpoint."""
|
|
|
|
"""An implementation of the SharedTracer that POSTS to the langchain endpoint."""
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
def __init__(self, **kwargs: Any) -> None:
|
|
|
|
def _convert_run(run: Union[LLMRun, ChainRun, ToolRun]) -> Run:
|
|
|
|
"""Initialize the LangChain tracer."""
|
|
|
|
"""Convert a run to a Run."""
|
|
|
|
super().__init__(**kwargs)
|
|
|
|
|
|
|
|
self._endpoint = _get_endpoint()
|
|
|
|
|
|
|
|
self._headers = _get_headers()
|
|
|
|
|
|
|
|
self.tenant_id = _get_tenant_id()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_session_create(
|
|
|
|
|
|
|
|
self, name: Optional[str] = None, **kwargs: Any
|
|
|
|
|
|
|
|
) -> TracerSessionBase:
|
|
|
|
|
|
|
|
return TracerSessionV2Create(name=name, extra=kwargs, tenant_id=self.tenant_id)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _persist_session(self, session_create: TracerSessionBase) -> TracerSessionV2:
|
|
|
|
|
|
|
|
"""Persist a session."""
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
|
|
|
r = requests.post(
|
|
|
|
|
|
|
|
f"{self._endpoint}/sessions",
|
|
|
|
|
|
|
|
data=session_create.json(),
|
|
|
|
|
|
|
|
headers=self._headers,
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
session = TracerSessionV2(id=r.json()["id"], **session_create.dict())
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
|
|
|
logging.warning(f"Failed to create session, using default session: {e}")
|
|
|
|
|
|
|
|
session = self.load_session("default")
|
|
|
|
|
|
|
|
return session
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_default_query_params(self) -> Dict[str, Any]:
|
|
|
|
|
|
|
|
"""Get the query params for the LangChain API."""
|
|
|
|
|
|
|
|
return {"tenant_id": self.tenant_id}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_session(self, session_name: str) -> TracerSessionV2:
|
|
|
|
|
|
|
|
"""Load a session with the given name from the tracer."""
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
|
|
|
url = f"{self._endpoint}/sessions"
|
|
|
|
|
|
|
|
params = {"tenant_id": self.tenant_id}
|
|
|
|
|
|
|
|
if session_name:
|
|
|
|
|
|
|
|
params["name"] = session_name
|
|
|
|
|
|
|
|
r = requests.get(url, headers=self._headers, params=params)
|
|
|
|
|
|
|
|
tracer_session = TracerSessionV2(**r.json()[0])
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
|
|
|
session_type = "default" if not session_name else session_name
|
|
|
|
|
|
|
|
logging.warning(
|
|
|
|
|
|
|
|
f"Failed to load {session_type} session, using empty session: {e}"
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
tracer_session = TracerSessionV2(id=1, tenant_id=self.tenant_id)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.session = tracer_session
|
|
|
|
|
|
|
|
return tracer_session
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_default_session(self) -> TracerSessionV2:
|
|
|
|
|
|
|
|
"""Load the default tracing session and set it as the Tracer's session."""
|
|
|
|
|
|
|
|
return self.load_session("default")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _convert_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> Run:
|
|
|
|
|
|
|
|
"""Convert a run to a Run."""
|
|
|
|
|
|
|
|
session = self.session or self.load_default_session()
|
|
|
|
inputs: Dict[str, Any] = {}
|
|
|
|
inputs: Dict[str, Any] = {}
|
|
|
|
outputs: Optional[Dict[str, Any]] = None
|
|
|
|
outputs: Optional[Dict[str, Any]] = None
|
|
|
|
child_runs: List[Union[LLMRun, ChainRun, ToolRun]] = []
|
|
|
|
child_runs: List[Union[LLMRun, ChainRun, ToolRun]] = []
|
|
|
@ -126,30 +206,30 @@ class LangChainTracerV2(LangChainTracer):
|
|
|
|
|
|
|
|
|
|
|
|
return Run(
|
|
|
|
return Run(
|
|
|
|
id=run.uuid,
|
|
|
|
id=run.uuid,
|
|
|
|
name=run.serialized.get("name"),
|
|
|
|
name=run.serialized.get("name", f"{run_type}-{run.uuid}"),
|
|
|
|
start_time=run.start_time,
|
|
|
|
start_time=run.start_time,
|
|
|
|
end_time=run.end_time,
|
|
|
|
end_time=run.end_time,
|
|
|
|
extra=run.extra,
|
|
|
|
extra=run.extra or {},
|
|
|
|
error=run.error,
|
|
|
|
error=run.error,
|
|
|
|
execution_order=run.execution_order,
|
|
|
|
execution_order=run.execution_order,
|
|
|
|
serialized=run.serialized,
|
|
|
|
serialized=run.serialized,
|
|
|
|
inputs=inputs,
|
|
|
|
inputs=inputs,
|
|
|
|
outputs=outputs,
|
|
|
|
outputs=outputs,
|
|
|
|
session_id=run.session_id,
|
|
|
|
session_id=session.id,
|
|
|
|
run_type=run_type,
|
|
|
|
run_type=run_type,
|
|
|
|
parent_run_id=run.parent_uuid,
|
|
|
|
parent_run_id=run.parent_uuid,
|
|
|
|
child_runs=[LangChainTracerV2._convert_run(child) for child in child_runs],
|
|
|
|
child_runs=[self._convert_run(child) for child in child_runs],
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:
|
|
|
|
def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:
|
|
|
|
"""Persist a run."""
|
|
|
|
"""Persist a run."""
|
|
|
|
run_create = self._convert_run(run)
|
|
|
|
run_create = self._convert_run(run)
|
|
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
try:
|
|
|
|
requests.post(
|
|
|
|
result = requests.post(
|
|
|
|
f"{self._endpoint}/runs",
|
|
|
|
f"{self._endpoint}/runs",
|
|
|
|
data=run_create.json(),
|
|
|
|
data=run_create.json(),
|
|
|
|
headers=self._headers,
|
|
|
|
headers=self._headers,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
result.raise_for_status()
|
|
|
|
except Exception as e:
|
|
|
|
except Exception as e:
|
|
|
|
logging.warning(f"Failed to persist run: {e}")
|
|
|
|
logging.warning(f"Failed to persist run: {e}")
|
|
|
|