import json import os import shutil import tempfile from copy import deepcopy from typing import Any, Dict, List, Optional from langchain_core.agents import AgentAction, AgentFinish from langchain_core.callbacks import BaseCallbackHandler from langchain_core.outputs import LLMResult from langchain_community.callbacks.utils import ( flatten_dict, ) def save_json(data: dict, file_path: str) -> None: """Save dict to local file path. Parameters: data (dict): The dictionary to be saved. file_path (str): Local file path. """ with open(file_path, "w") as outfile: json.dump(data, outfile) class SageMakerCallbackHandler(BaseCallbackHandler): """Callback Handler that logs prompt artifacts and metrics to SageMaker Experiments. Parameters: run (sagemaker.experiments.run.Run): Run object where the experiment is logged. """ def __init__(self, run: Any) -> None: """Initialize callback handler.""" super().__init__() self.run = run self.metrics = { "step": 0, "starts": 0, "ends": 0, "errors": 0, "text_ctr": 0, "chain_starts": 0, "chain_ends": 0, "llm_starts": 0, "llm_ends": 0, "llm_streams": 0, "tool_starts": 0, "tool_ends": 0, "agent_ends": 0, } # Create a temporary directory self.temp_dir = tempfile.mkdtemp() def _reset(self) -> None: for k, v in self.metrics.items(): self.metrics[k] = 0 def on_llm_start( self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any ) -> None: """Run when LLM starts.""" self.metrics["step"] += 1 self.metrics["llm_starts"] += 1 self.metrics["starts"] += 1 llm_starts = self.metrics["llm_starts"] resp: Dict[str, Any] = {} resp.update({"action": "on_llm_start"}) resp.update(flatten_dict(serialized)) resp.update(self.metrics) for idx, prompt in enumerate(prompts): prompt_resp = deepcopy(resp) prompt_resp["prompt"] = prompt self.jsonf( prompt_resp, self.temp_dir, f"llm_start_{llm_starts}_prompt_{idx}", ) def on_llm_new_token(self, token: str, **kwargs: Any) -> None: """Run when LLM generates a new token.""" self.metrics["step"] += 1 self.metrics["llm_streams"] += 1 llm_streams = self.metrics["llm_streams"] resp: Dict[str, Any] = {} resp.update({"action": "on_llm_new_token", "token": token}) resp.update(self.metrics) self.jsonf(resp, self.temp_dir, f"llm_new_tokens_{llm_streams}") def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: """Run when LLM ends running.""" self.metrics["step"] += 1 self.metrics["llm_ends"] += 1 self.metrics["ends"] += 1 llm_ends = self.metrics["llm_ends"] resp: Dict[str, Any] = {} resp.update({"action": "on_llm_end"}) resp.update(flatten_dict(response.llm_output or {})) resp.update(self.metrics) for generations in response.generations: for idx, generation in enumerate(generations): generation_resp = deepcopy(resp) generation_resp.update(flatten_dict(generation.dict())) self.jsonf( resp, self.temp_dir, f"llm_end_{llm_ends}_generation_{idx}", ) def on_llm_error(self, error: BaseException, **kwargs: Any) -> None: """Run when LLM errors.""" self.metrics["step"] += 1 self.metrics["errors"] += 1 def on_chain_start( self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any ) -> None: """Run when chain starts running.""" self.metrics["step"] += 1 self.metrics["chain_starts"] += 1 self.metrics["starts"] += 1 chain_starts = self.metrics["chain_starts"] resp: Dict[str, Any] = {} resp.update({"action": "on_chain_start"}) resp.update(flatten_dict(serialized)) resp.update(self.metrics) chain_input = ",".join([f"{k}={v}" for k, v in inputs.items()]) input_resp = deepcopy(resp) input_resp["inputs"] = chain_input self.jsonf(input_resp, self.temp_dir, f"chain_start_{chain_starts}") def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: """Run when chain ends running.""" self.metrics["step"] += 1 self.metrics["chain_ends"] += 1 self.metrics["ends"] += 1 chain_ends = self.metrics["chain_ends"] resp: Dict[str, Any] = {} chain_output = ",".join([f"{k}={v}" for k, v in outputs.items()]) resp.update({"action": "on_chain_end", "outputs": chain_output}) resp.update(self.metrics) self.jsonf(resp, self.temp_dir, f"chain_end_{chain_ends}") def on_chain_error(self, error: BaseException, **kwargs: Any) -> None: """Run when chain errors.""" self.metrics["step"] += 1 self.metrics["errors"] += 1 def on_tool_start( self, serialized: Dict[str, Any], input_str: str, **kwargs: Any ) -> None: """Run when tool starts running.""" self.metrics["step"] += 1 self.metrics["tool_starts"] += 1 self.metrics["starts"] += 1 tool_starts = self.metrics["tool_starts"] resp: Dict[str, Any] = {} resp.update({"action": "on_tool_start", "input_str": input_str}) resp.update(flatten_dict(serialized)) resp.update(self.metrics) self.jsonf(resp, self.temp_dir, f"tool_start_{tool_starts}") def on_tool_end(self, output: str, **kwargs: Any) -> None: """Run when tool ends running.""" self.metrics["step"] += 1 self.metrics["tool_ends"] += 1 self.metrics["ends"] += 1 tool_ends = self.metrics["tool_ends"] resp: Dict[str, Any] = {} resp.update({"action": "on_tool_end", "output": output}) resp.update(self.metrics) self.jsonf(resp, self.temp_dir, f"tool_end_{tool_ends}") def on_tool_error(self, error: BaseException, **kwargs: Any) -> None: """Run when tool errors.""" self.metrics["step"] += 1 self.metrics["errors"] += 1 def on_text(self, text: str, **kwargs: Any) -> None: """ Run when agent is ending. """ self.metrics["step"] += 1 self.metrics["text_ctr"] += 1 text_ctr = self.metrics["text_ctr"] resp: Dict[str, Any] = {} resp.update({"action": "on_text", "text": text}) resp.update(self.metrics) self.jsonf(resp, self.temp_dir, f"on_text_{text_ctr}") def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None: """Run when agent ends running.""" self.metrics["step"] += 1 self.metrics["agent_ends"] += 1 self.metrics["ends"] += 1 agent_ends = self.metrics["agent_ends"] resp: Dict[str, Any] = {} resp.update( { "action": "on_agent_finish", "output": finish.return_values["output"], "log": finish.log, } ) resp.update(self.metrics) self.jsonf(resp, self.temp_dir, f"agent_finish_{agent_ends}") def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any: """Run on agent action.""" self.metrics["step"] += 1 self.metrics["tool_starts"] += 1 self.metrics["starts"] += 1 tool_starts = self.metrics["tool_starts"] resp: Dict[str, Any] = {} resp.update( { "action": "on_agent_action", "tool": action.tool, "tool_input": action.tool_input, "log": action.log, } ) resp.update(self.metrics) self.jsonf(resp, self.temp_dir, f"agent_action_{tool_starts}") def jsonf( self, data: Dict[str, Any], data_dir: str, filename: str, is_output: Optional[bool] = True, ) -> None: """To log the input data as json file artifact.""" file_path = os.path.join(data_dir, f"{filename}.json") save_json(data, file_path) self.run.log_file(file_path, name=filename, is_output=is_output) def flush_tracker(self) -> None: """Reset the steps and delete the temporary local directory.""" self._reset() shutil.rmtree(self.temp_dir)