mirror of
https://github.com/hwchase17/langchain
synced 2024-11-04 06:00:26 +00:00
core[patch]: BaseTracer
helper method for Run
lookup (#14139)
I observed the same run ID extraction logic is repeated many times in `BaseTracer`. This PR creates a helper method for DRY code.
This commit is contained in:
parent
41ee3be95f
commit
bdb6ae2ed3
@ -92,6 +92,17 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
|
||||
return parent_run.child_execution_order + 1
|
||||
|
||||
def _get_run(self, run_id: UUID, run_type: str | None = None) -> Run:
|
||||
try:
|
||||
run = self.run_map[str(run_id)]
|
||||
except KeyError as exc:
|
||||
raise TracerException(f"No indexed run ID {run_id}.") from exc
|
||||
if run_type is not None and run.run_type != run_type:
|
||||
raise TracerException(
|
||||
f"Found {run.run_type} run at ID {run_id}, but expected {run_type} run."
|
||||
)
|
||||
return run
|
||||
|
||||
def on_llm_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
@ -138,13 +149,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
**kwargs: Any,
|
||||
) -> Run:
|
||||
"""Run on new LLM token. Only available when streaming is enabled."""
|
||||
if not run_id:
|
||||
raise TracerException("No run_id provided for on_llm_new_token callback.")
|
||||
|
||||
run_id_ = str(run_id)
|
||||
llm_run = self.run_map.get(run_id_)
|
||||
if llm_run is None or llm_run.run_type != "llm":
|
||||
raise TracerException(f"No LLM Run found to be traced for {run_id}")
|
||||
llm_run = self._get_run(run_id, run_type="llm")
|
||||
event_kwargs: Dict[str, Any] = {"token": token}
|
||||
if chunk:
|
||||
event_kwargs["chunk"] = chunk
|
||||
@ -165,12 +170,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
run_id: UUID,
|
||||
**kwargs: Any,
|
||||
) -> Run:
|
||||
if not run_id:
|
||||
raise TracerException("No run_id provided for on_retry callback.")
|
||||
run_id_ = str(run_id)
|
||||
llm_run = self.run_map.get(run_id_)
|
||||
if llm_run is None:
|
||||
raise TracerException("No Run found to be traced for on_retry")
|
||||
llm_run = self._get_run(run_id)
|
||||
retry_d: Dict[str, Any] = {
|
||||
"slept": retry_state.idle_for,
|
||||
"attempt": retry_state.attempt_number,
|
||||
@ -196,13 +196,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
|
||||
def on_llm_end(self, response: LLMResult, *, run_id: UUID, **kwargs: Any) -> Run:
|
||||
"""End a trace for an LLM run."""
|
||||
if not run_id:
|
||||
raise TracerException("No run_id provided for on_llm_end callback.")
|
||||
|
||||
run_id_ = str(run_id)
|
||||
llm_run = self.run_map.get(run_id_)
|
||||
if llm_run is None or llm_run.run_type != "llm":
|
||||
raise TracerException(f"No LLM Run found to be traced for {run_id}")
|
||||
llm_run = self._get_run(run_id, run_type="llm")
|
||||
llm_run.outputs = response.dict()
|
||||
for i, generations in enumerate(response.generations):
|
||||
for j, generation in enumerate(generations):
|
||||
@ -225,13 +219,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
**kwargs: Any,
|
||||
) -> Run:
|
||||
"""Handle an error for an LLM run."""
|
||||
if not run_id:
|
||||
raise TracerException("No run_id provided for on_llm_error callback.")
|
||||
|
||||
run_id_ = str(run_id)
|
||||
llm_run = self.run_map.get(run_id_)
|
||||
if llm_run is None or llm_run.run_type != "llm":
|
||||
raise TracerException(f"No LLM Run found to be traced for {run_id}")
|
||||
llm_run = self._get_run(run_id, run_type="llm")
|
||||
llm_run.error = repr(error)
|
||||
llm_run.end_time = datetime.utcnow()
|
||||
llm_run.events.append({"name": "error", "time": llm_run.end_time})
|
||||
@ -286,12 +274,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
**kwargs: Any,
|
||||
) -> Run:
|
||||
"""End a trace for a chain run."""
|
||||
if not run_id:
|
||||
raise TracerException("No run_id provided for on_chain_end callback.")
|
||||
chain_run = self.run_map.get(str(run_id))
|
||||
if chain_run is None:
|
||||
raise TracerException(f"No chain Run found to be traced for {run_id}")
|
||||
|
||||
chain_run = self._get_run(run_id)
|
||||
chain_run.outputs = (
|
||||
outputs if isinstance(outputs, dict) else {"output": outputs}
|
||||
)
|
||||
@ -312,12 +295,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
**kwargs: Any,
|
||||
) -> Run:
|
||||
"""Handle an error for a chain run."""
|
||||
if not run_id:
|
||||
raise TracerException("No run_id provided for on_chain_error callback.")
|
||||
chain_run = self.run_map.get(str(run_id))
|
||||
if chain_run is None:
|
||||
raise TracerException(f"No chain Run found to be traced for {run_id}")
|
||||
|
||||
chain_run = self._get_run(run_id)
|
||||
chain_run.error = repr(error)
|
||||
chain_run.end_time = datetime.utcnow()
|
||||
chain_run.events.append({"name": "error", "time": chain_run.end_time})
|
||||
@ -366,12 +344,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
|
||||
def on_tool_end(self, output: str, *, run_id: UUID, **kwargs: Any) -> Run:
|
||||
"""End a trace for a tool run."""
|
||||
if not run_id:
|
||||
raise TracerException("No run_id provided for on_tool_end callback.")
|
||||
tool_run = self.run_map.get(str(run_id))
|
||||
if tool_run is None or tool_run.run_type != "tool":
|
||||
raise TracerException(f"No tool Run found to be traced for {run_id}")
|
||||
|
||||
tool_run = self._get_run(run_id, run_type="tool")
|
||||
tool_run.outputs = {"output": output}
|
||||
tool_run.end_time = datetime.utcnow()
|
||||
tool_run.events.append({"name": "end", "time": tool_run.end_time})
|
||||
@ -387,12 +360,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
**kwargs: Any,
|
||||
) -> Run:
|
||||
"""Handle an error for a tool run."""
|
||||
if not run_id:
|
||||
raise TracerException("No run_id provided for on_tool_error callback.")
|
||||
tool_run = self.run_map.get(str(run_id))
|
||||
if tool_run is None or tool_run.run_type != "tool":
|
||||
raise TracerException(f"No tool Run found to be traced for {run_id}")
|
||||
|
||||
tool_run = self._get_run(run_id, run_type="tool")
|
||||
tool_run.error = repr(error)
|
||||
tool_run.end_time = datetime.utcnow()
|
||||
tool_run.events.append({"name": "error", "time": tool_run.end_time})
|
||||
@ -445,12 +413,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
**kwargs: Any,
|
||||
) -> Run:
|
||||
"""Run when Retriever errors."""
|
||||
if not run_id:
|
||||
raise TracerException("No run_id provided for on_retriever_error callback.")
|
||||
retrieval_run = self.run_map.get(str(run_id))
|
||||
if retrieval_run is None or retrieval_run.run_type != "retriever":
|
||||
raise TracerException(f"No retriever Run found to be traced for {run_id}")
|
||||
|
||||
retrieval_run = self._get_run(run_id, run_type="retriever")
|
||||
retrieval_run.error = repr(error)
|
||||
retrieval_run.end_time = datetime.utcnow()
|
||||
retrieval_run.events.append({"name": "error", "time": retrieval_run.end_time})
|
||||
@ -462,11 +425,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
self, documents: Sequence[Document], *, run_id: UUID, **kwargs: Any
|
||||
) -> Run:
|
||||
"""Run when Retriever ends running."""
|
||||
if not run_id:
|
||||
raise TracerException("No run_id provided for on_retriever_end callback.")
|
||||
retrieval_run = self.run_map.get(str(run_id))
|
||||
if retrieval_run is None or retrieval_run.run_type != "retriever":
|
||||
raise TracerException(f"No retriever Run found to be traced for {run_id}")
|
||||
retrieval_run = self._get_run(run_id, run_type="retriever")
|
||||
retrieval_run.outputs = {"documents": documents}
|
||||
retrieval_run.end_time = datetime.utcnow()
|
||||
retrieval_run.events.append({"name": "end", "time": retrieval_run.end_time})
|
||||
|
Loading…
Reference in New Issue
Block a user