core[patch]: BaseTracer helper method for Run lookup (#14139)

I observed the same run ID extraction logic is repeated many times in
`BaseTracer`.

This PR creates a helper method for DRY code.
This commit is contained in:
James Braza 2023-12-02 17:05:50 -05:00 committed by GitHub
parent 41ee3be95f
commit bdb6ae2ed3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -92,6 +92,17 @@ class BaseTracer(BaseCallbackHandler, ABC):
return parent_run.child_execution_order + 1
def _get_run(self, run_id: UUID, run_type: str | None = None) -> Run:
try:
run = self.run_map[str(run_id)]
except KeyError as exc:
raise TracerException(f"No indexed run ID {run_id}.") from exc
if run_type is not None and run.run_type != run_type:
raise TracerException(
f"Found {run.run_type} run at ID {run_id}, but expected {run_type} run."
)
return run
def on_llm_start(
self,
serialized: Dict[str, Any],
@ -138,13 +149,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
**kwargs: Any,
) -> Run:
"""Run on new LLM token. Only available when streaming is enabled."""
if not run_id:
raise TracerException("No run_id provided for on_llm_new_token callback.")
run_id_ = str(run_id)
llm_run = self.run_map.get(run_id_)
if llm_run is None or llm_run.run_type != "llm":
raise TracerException(f"No LLM Run found to be traced for {run_id}")
llm_run = self._get_run(run_id, run_type="llm")
event_kwargs: Dict[str, Any] = {"token": token}
if chunk:
event_kwargs["chunk"] = chunk
@ -165,12 +170,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
run_id: UUID,
**kwargs: Any,
) -> Run:
if not run_id:
raise TracerException("No run_id provided for on_retry callback.")
run_id_ = str(run_id)
llm_run = self.run_map.get(run_id_)
if llm_run is None:
raise TracerException("No Run found to be traced for on_retry")
llm_run = self._get_run(run_id)
retry_d: Dict[str, Any] = {
"slept": retry_state.idle_for,
"attempt": retry_state.attempt_number,
@ -196,13 +196,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
def on_llm_end(self, response: LLMResult, *, run_id: UUID, **kwargs: Any) -> Run:
"""End a trace for an LLM run."""
if not run_id:
raise TracerException("No run_id provided for on_llm_end callback.")
run_id_ = str(run_id)
llm_run = self.run_map.get(run_id_)
if llm_run is None or llm_run.run_type != "llm":
raise TracerException(f"No LLM Run found to be traced for {run_id}")
llm_run = self._get_run(run_id, run_type="llm")
llm_run.outputs = response.dict()
for i, generations in enumerate(response.generations):
for j, generation in enumerate(generations):
@ -225,13 +219,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
**kwargs: Any,
) -> Run:
"""Handle an error for an LLM run."""
if not run_id:
raise TracerException("No run_id provided for on_llm_error callback.")
run_id_ = str(run_id)
llm_run = self.run_map.get(run_id_)
if llm_run is None or llm_run.run_type != "llm":
raise TracerException(f"No LLM Run found to be traced for {run_id}")
llm_run = self._get_run(run_id, run_type="llm")
llm_run.error = repr(error)
llm_run.end_time = datetime.utcnow()
llm_run.events.append({"name": "error", "time": llm_run.end_time})
@ -286,12 +274,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
**kwargs: Any,
) -> Run:
"""End a trace for a chain run."""
if not run_id:
raise TracerException("No run_id provided for on_chain_end callback.")
chain_run = self.run_map.get(str(run_id))
if chain_run is None:
raise TracerException(f"No chain Run found to be traced for {run_id}")
chain_run = self._get_run(run_id)
chain_run.outputs = (
outputs if isinstance(outputs, dict) else {"output": outputs}
)
@ -312,12 +295,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
**kwargs: Any,
) -> Run:
"""Handle an error for a chain run."""
if not run_id:
raise TracerException("No run_id provided for on_chain_error callback.")
chain_run = self.run_map.get(str(run_id))
if chain_run is None:
raise TracerException(f"No chain Run found to be traced for {run_id}")
chain_run = self._get_run(run_id)
chain_run.error = repr(error)
chain_run.end_time = datetime.utcnow()
chain_run.events.append({"name": "error", "time": chain_run.end_time})
@ -366,12 +344,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
def on_tool_end(self, output: str, *, run_id: UUID, **kwargs: Any) -> Run:
"""End a trace for a tool run."""
if not run_id:
raise TracerException("No run_id provided for on_tool_end callback.")
tool_run = self.run_map.get(str(run_id))
if tool_run is None or tool_run.run_type != "tool":
raise TracerException(f"No tool Run found to be traced for {run_id}")
tool_run = self._get_run(run_id, run_type="tool")
tool_run.outputs = {"output": output}
tool_run.end_time = datetime.utcnow()
tool_run.events.append({"name": "end", "time": tool_run.end_time})
@ -387,12 +360,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
**kwargs: Any,
) -> Run:
"""Handle an error for a tool run."""
if not run_id:
raise TracerException("No run_id provided for on_tool_error callback.")
tool_run = self.run_map.get(str(run_id))
if tool_run is None or tool_run.run_type != "tool":
raise TracerException(f"No tool Run found to be traced for {run_id}")
tool_run = self._get_run(run_id, run_type="tool")
tool_run.error = repr(error)
tool_run.end_time = datetime.utcnow()
tool_run.events.append({"name": "error", "time": tool_run.end_time})
@ -445,12 +413,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
**kwargs: Any,
) -> Run:
"""Run when Retriever errors."""
if not run_id:
raise TracerException("No run_id provided for on_retriever_error callback.")
retrieval_run = self.run_map.get(str(run_id))
if retrieval_run is None or retrieval_run.run_type != "retriever":
raise TracerException(f"No retriever Run found to be traced for {run_id}")
retrieval_run = self._get_run(run_id, run_type="retriever")
retrieval_run.error = repr(error)
retrieval_run.end_time = datetime.utcnow()
retrieval_run.events.append({"name": "error", "time": retrieval_run.end_time})
@ -462,11 +425,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
self, documents: Sequence[Document], *, run_id: UUID, **kwargs: Any
) -> Run:
"""Run when Retriever ends running."""
if not run_id:
raise TracerException("No run_id provided for on_retriever_end callback.")
retrieval_run = self.run_map.get(str(run_id))
if retrieval_run is None or retrieval_run.run_type != "retriever":
raise TracerException(f"No retriever Run found to be traced for {run_id}")
retrieval_run = self._get_run(run_id, run_type="retriever")
retrieval_run.outputs = {"documents": documents}
retrieval_run.end_time = datetime.utcnow()
retrieval_run.events.append({"name": "end", "time": retrieval_run.end_time})