import tempfile from copy import deepcopy from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Sequence from langchain_core.agents import AgentAction, AgentFinish from langchain_core.callbacks import BaseCallbackHandler from langchain_core.outputs import Generation, LLMResult from langchain_core.utils import guard_import import langchain_community from langchain_community.callbacks.utils import ( BaseMetadataCallbackHandler, flatten_dict, import_pandas, import_spacy, import_textstat, ) LANGCHAIN_MODEL_NAME = "langchain-model" def import_comet_ml() -> Any: """Import comet_ml and raise an error if it is not installed.""" return guard_import("comet_ml") def _get_experiment( workspace: Optional[str] = None, project_name: Optional[str] = None ) -> Any: comet_ml = import_comet_ml() experiment = comet_ml.Experiment( workspace=workspace, project_name=project_name, ) return experiment def _fetch_text_complexity_metrics(text: str) -> dict: textstat = import_textstat() text_complexity_metrics = { "flesch_reading_ease": textstat.flesch_reading_ease(text), "flesch_kincaid_grade": textstat.flesch_kincaid_grade(text), "smog_index": textstat.smog_index(text), "coleman_liau_index": textstat.coleman_liau_index(text), "automated_readability_index": textstat.automated_readability_index(text), "dale_chall_readability_score": textstat.dale_chall_readability_score(text), "difficult_words": textstat.difficult_words(text), "linsear_write_formula": textstat.linsear_write_formula(text), "gunning_fog": textstat.gunning_fog(text), "text_standard": textstat.text_standard(text), "fernandez_huerta": textstat.fernandez_huerta(text), "szigriszt_pazos": textstat.szigriszt_pazos(text), "gutierrez_polini": textstat.gutierrez_polini(text), "crawford": textstat.crawford(text), "gulpease_index": textstat.gulpease_index(text), "osman": textstat.osman(text), } return text_complexity_metrics def _summarize_metrics_for_generated_outputs(metrics: Sequence) -> dict: pd = import_pandas() metrics_df = pd.DataFrame(metrics) metrics_summary = metrics_df.describe() return metrics_summary.to_dict() class CometCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler): """Callback Handler that logs to Comet. Parameters: job_type (str): The type of comet_ml task such as "inference", "testing" or "qc" project_name (str): The comet_ml project name tags (list): Tags to add to the task task_name (str): Name of the comet_ml task visualize (bool): Whether to visualize the run. complexity_metrics (bool): Whether to log complexity metrics stream_logs (bool): Whether to stream callback actions to Comet This handler will utilize the associated callback method and formats the input of each callback function with metadata regarding the state of LLM run, and adds the response to the list of records for both the {method}_records and action. It then logs the response to Comet. """ def __init__( self, task_type: Optional[str] = "inference", workspace: Optional[str] = None, project_name: Optional[str] = None, tags: Optional[Sequence] = None, name: Optional[str] = None, visualizations: Optional[List[str]] = None, complexity_metrics: bool = False, custom_metrics: Optional[Callable] = None, stream_logs: bool = True, ) -> None: """Initialize callback handler.""" self.comet_ml = import_comet_ml() super().__init__() self.task_type = task_type self.workspace = workspace self.project_name = project_name self.tags = tags self.visualizations = visualizations self.complexity_metrics = complexity_metrics self.custom_metrics = custom_metrics self.stream_logs = stream_logs self.temp_dir = tempfile.TemporaryDirectory() self.experiment = _get_experiment(workspace, project_name) self.experiment.log_other("Created from", "langchain") if tags: self.experiment.add_tags(tags) self.name = name if self.name: self.experiment.set_name(self.name) warning = ( "The comet_ml callback is currently in beta and is subject to change " "based on updates to `langchain`. Please report any issues to " "https://github.com/comet-ml/issue-tracking/issues with the tag " "`langchain`." ) self.comet_ml.LOGGER.warning(warning) self.callback_columns: list = [] self.action_records: list = [] self.complexity_metrics = complexity_metrics if self.visualizations: spacy = import_spacy() self.nlp = spacy.load("en_core_web_sm") else: self.nlp = None def _init_resp(self) -> Dict: return {k: None for k in self.callback_columns} def on_llm_start( self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any ) -> None: """Run when LLM starts.""" self.step += 1 self.llm_starts += 1 self.starts += 1 metadata = self._init_resp() metadata.update({"action": "on_llm_start"}) metadata.update(flatten_dict(serialized)) metadata.update(self.get_custom_callback_meta()) for prompt in prompts: prompt_resp = deepcopy(metadata) prompt_resp["prompts"] = prompt self.on_llm_start_records.append(prompt_resp) self.action_records.append(prompt_resp) if self.stream_logs: self._log_stream(prompt, metadata, self.step) def on_llm_new_token(self, token: str, **kwargs: Any) -> None: """Run when LLM generates a new token.""" self.step += 1 self.llm_streams += 1 resp = self._init_resp() resp.update({"action": "on_llm_new_token", "token": token}) resp.update(self.get_custom_callback_meta()) self.action_records.append(resp) def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: """Run when LLM ends running.""" self.step += 1 self.llm_ends += 1 self.ends += 1 metadata = self._init_resp() metadata.update({"action": "on_llm_end"}) metadata.update(flatten_dict(response.llm_output or {})) metadata.update(self.get_custom_callback_meta()) output_complexity_metrics = [] output_custom_metrics = [] for prompt_idx, generations in enumerate(response.generations): for gen_idx, generation in enumerate(generations): text = generation.text generation_resp = deepcopy(metadata) generation_resp.update(flatten_dict(generation.dict())) complexity_metrics = self._get_complexity_metrics(text) if complexity_metrics: output_complexity_metrics.append(complexity_metrics) generation_resp.update(complexity_metrics) custom_metrics = self._get_custom_metrics( generation, prompt_idx, gen_idx ) if custom_metrics: output_custom_metrics.append(custom_metrics) generation_resp.update(custom_metrics) if self.stream_logs: self._log_stream(text, metadata, self.step) self.action_records.append(generation_resp) self.on_llm_end_records.append(generation_resp) self._log_text_metrics(output_complexity_metrics, step=self.step) self._log_text_metrics(output_custom_metrics, step=self.step) def on_llm_error(self, error: BaseException, **kwargs: Any) -> None: """Run when LLM errors.""" self.step += 1 self.errors += 1 def on_chain_start( self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any ) -> None: """Run when chain starts running.""" self.step += 1 self.chain_starts += 1 self.starts += 1 resp = self._init_resp() resp.update({"action": "on_chain_start"}) resp.update(flatten_dict(serialized)) resp.update(self.get_custom_callback_meta()) for chain_input_key, chain_input_val in inputs.items(): if isinstance(chain_input_val, str): input_resp = deepcopy(resp) if self.stream_logs: self._log_stream(chain_input_val, resp, self.step) input_resp.update({chain_input_key: chain_input_val}) self.action_records.append(input_resp) else: self.comet_ml.LOGGER.warning( f"Unexpected data format provided! " f"Input Value for {chain_input_key} will not be logged" ) def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: """Run when chain ends running.""" self.step += 1 self.chain_ends += 1 self.ends += 1 resp = self._init_resp() resp.update({"action": "on_chain_end"}) resp.update(self.get_custom_callback_meta()) for chain_output_key, chain_output_val in outputs.items(): if isinstance(chain_output_val, str): output_resp = deepcopy(resp) if self.stream_logs: self._log_stream(chain_output_val, resp, self.step) output_resp.update({chain_output_key: chain_output_val}) self.action_records.append(output_resp) else: self.comet_ml.LOGGER.warning( f"Unexpected data format provided! " f"Output Value for {chain_output_key} will not be logged" ) def on_chain_error(self, error: BaseException, **kwargs: Any) -> None: """Run when chain errors.""" self.step += 1 self.errors += 1 def on_tool_start( self, serialized: Dict[str, Any], input_str: str, **kwargs: Any ) -> None: """Run when tool starts running.""" self.step += 1 self.tool_starts += 1 self.starts += 1 resp = self._init_resp() resp.update({"action": "on_tool_start"}) resp.update(flatten_dict(serialized)) resp.update(self.get_custom_callback_meta()) if self.stream_logs: self._log_stream(input_str, resp, self.step) resp.update({"input_str": input_str}) self.action_records.append(resp) def on_tool_end(self, output: Any, **kwargs: Any) -> None: """Run when tool ends running.""" output = str(output) self.step += 1 self.tool_ends += 1 self.ends += 1 resp = self._init_resp() resp.update({"action": "on_tool_end"}) resp.update(self.get_custom_callback_meta()) if self.stream_logs: self._log_stream(output, resp, self.step) resp.update({"output": output}) self.action_records.append(resp) def on_tool_error(self, error: BaseException, **kwargs: Any) -> None: """Run when tool errors.""" self.step += 1 self.errors += 1 def on_text(self, text: str, **kwargs: Any) -> None: """ Run when agent is ending. """ self.step += 1 self.text_ctr += 1 resp = self._init_resp() resp.update({"action": "on_text"}) resp.update(self.get_custom_callback_meta()) if self.stream_logs: self._log_stream(text, resp, self.step) resp.update({"text": text}) self.action_records.append(resp) def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None: """Run when agent ends running.""" self.step += 1 self.agent_ends += 1 self.ends += 1 resp = self._init_resp() output = finish.return_values["output"] log = finish.log resp.update({"action": "on_agent_finish", "log": log}) resp.update(self.get_custom_callback_meta()) if self.stream_logs: self._log_stream(output, resp, self.step) resp.update({"output": output}) self.action_records.append(resp) def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any: """Run on agent action.""" self.step += 1 self.tool_starts += 1 self.starts += 1 tool = action.tool tool_input = str(action.tool_input) log = action.log resp = self._init_resp() resp.update({"action": "on_agent_action", "log": log, "tool": tool}) resp.update(self.get_custom_callback_meta()) if self.stream_logs: self._log_stream(tool_input, resp, self.step) resp.update({"tool_input": tool_input}) self.action_records.append(resp) def _get_complexity_metrics(self, text: str) -> dict: """Compute text complexity metrics using textstat. Parameters: text (str): The text to analyze. Returns: (dict): A dictionary containing the complexity metrics. """ resp = {} if self.complexity_metrics: text_complexity_metrics = _fetch_text_complexity_metrics(text) resp.update(text_complexity_metrics) return resp def _get_custom_metrics( self, generation: Generation, prompt_idx: int, gen_idx: int ) -> dict: """Compute Custom Metrics for an LLM Generated Output Args: generation (LLMResult): Output generation from an LLM prompt_idx (int): List index of the input prompt gen_idx (int): List index of the generated output Returns: dict: A dictionary containing the custom metrics. """ resp = {} if self.custom_metrics: custom_metrics = self.custom_metrics(generation, prompt_idx, gen_idx) resp.update(custom_metrics) return resp def flush_tracker( self, langchain_asset: Any = None, task_type: Optional[str] = "inference", workspace: Optional[str] = None, project_name: Optional[str] = "comet-langchain-demo", tags: Optional[Sequence] = None, name: Optional[str] = None, visualizations: Optional[List[str]] = None, complexity_metrics: bool = False, custom_metrics: Optional[Callable] = None, finish: bool = False, reset: bool = False, ) -> None: """Flush the tracker and setup the session. Everything after this will be a new table. Args: name: Name of the performed session so far so it is identifiable langchain_asset: The langchain asset to save. finish: Whether to finish the run. Returns: None """ self._log_session(langchain_asset) if langchain_asset: try: self._log_model(langchain_asset) except Exception: self.comet_ml.LOGGER.error( "Failed to export agent or LLM to Comet", exc_info=True, extra={"show_traceback": True}, ) if finish: self.experiment.end() if reset: self._reset( task_type, workspace, project_name, tags, name, visualizations, complexity_metrics, custom_metrics, ) def _log_stream(self, prompt: str, metadata: dict, step: int) -> None: self.experiment.log_text(prompt, metadata=metadata, step=step) def _log_model(self, langchain_asset: Any) -> None: model_parameters = self._get_llm_parameters(langchain_asset) self.experiment.log_parameters(model_parameters, prefix="model") langchain_asset_path = Path(self.temp_dir.name, "model.json") model_name = self.name if self.name else LANGCHAIN_MODEL_NAME try: if hasattr(langchain_asset, "save"): langchain_asset.save(langchain_asset_path) self.experiment.log_model(model_name, str(langchain_asset_path)) except (ValueError, AttributeError, NotImplementedError) as e: if hasattr(langchain_asset, "save_agent"): langchain_asset.save_agent(langchain_asset_path) self.experiment.log_model(model_name, str(langchain_asset_path)) else: self.comet_ml.LOGGER.error( f"{e}" " Could not save Langchain Asset " f"for {langchain_asset.__class__.__name__}" ) def _log_session(self, langchain_asset: Optional[Any] = None) -> None: try: llm_session_df = self._create_session_analysis_dataframe(langchain_asset) # Log the cleaned dataframe as a table self.experiment.log_table("langchain-llm-session.csv", llm_session_df) except Exception: self.comet_ml.LOGGER.warning( "Failed to log session data to Comet", exc_info=True, extra={"show_traceback": True}, ) try: metadata = {"langchain_version": str(langchain_community.__version__)} # Log the langchain low-level records as a JSON file directly self.experiment.log_asset_data( self.action_records, "langchain-action_records.json", metadata=metadata ) except Exception: self.comet_ml.LOGGER.warning( "Failed to log session data to Comet", exc_info=True, extra={"show_traceback": True}, ) try: self._log_visualizations(llm_session_df) except Exception: self.comet_ml.LOGGER.warning( "Failed to log visualizations to Comet", exc_info=True, extra={"show_traceback": True}, ) def _log_text_metrics(self, metrics: Sequence[dict], step: int) -> None: if not metrics: return metrics_summary = _summarize_metrics_for_generated_outputs(metrics) for key, value in metrics_summary.items(): self.experiment.log_metrics(value, prefix=key, step=step) def _log_visualizations(self, session_df: Any) -> None: if not (self.visualizations and self.nlp): return spacy = import_spacy() prompts = session_df["prompts"].tolist() outputs = session_df["text"].tolist() for idx, (prompt, output) in enumerate(zip(prompts, outputs)): doc = self.nlp(output) sentence_spans = list(doc.sents) for visualization in self.visualizations: try: html = spacy.displacy.render( sentence_spans, style=visualization, options={"compact": True}, jupyter=False, page=True, ) self.experiment.log_asset_data( html, name=f"langchain-viz-{visualization}-{idx}.html", metadata={"prompt": prompt}, step=idx, ) except Exception as e: self.comet_ml.LOGGER.warning( e, exc_info=True, extra={"show_traceback": True} ) return def _reset( self, task_type: Optional[str] = None, workspace: Optional[str] = None, project_name: Optional[str] = None, tags: Optional[Sequence] = None, name: Optional[str] = None, visualizations: Optional[List[str]] = None, complexity_metrics: bool = False, custom_metrics: Optional[Callable] = None, ) -> None: _task_type = task_type if task_type else self.task_type _workspace = workspace if workspace else self.workspace _project_name = project_name if project_name else self.project_name _tags = tags if tags else self.tags _name = name if name else self.name _visualizations = visualizations if visualizations else self.visualizations _complexity_metrics = ( complexity_metrics if complexity_metrics else self.complexity_metrics ) _custom_metrics = custom_metrics if custom_metrics else self.custom_metrics self.__init__( # type: ignore task_type=_task_type, workspace=_workspace, project_name=_project_name, tags=_tags, name=_name, visualizations=_visualizations, complexity_metrics=_complexity_metrics, custom_metrics=_custom_metrics, ) self.reset_callback_meta() self.temp_dir = tempfile.TemporaryDirectory() def _create_session_analysis_dataframe(self, langchain_asset: Any = None) -> dict: pd = import_pandas() llm_parameters = self._get_llm_parameters(langchain_asset) num_generations_per_prompt = llm_parameters.get("n", 1) llm_start_records_df = pd.DataFrame(self.on_llm_start_records) # Repeat each input row based on the number of outputs generated per prompt llm_start_records_df = llm_start_records_df.loc[ llm_start_records_df.index.repeat(num_generations_per_prompt) ].reset_index(drop=True) llm_end_records_df = pd.DataFrame(self.on_llm_end_records) llm_session_df = pd.merge( llm_start_records_df, llm_end_records_df, left_index=True, right_index=True, suffixes=["_llm_start", "_llm_end"], ) return llm_session_df def _get_llm_parameters(self, langchain_asset: Any = None) -> dict: if not langchain_asset: return {} try: if hasattr(langchain_asset, "agent"): llm_parameters = langchain_asset.agent.llm_chain.llm.dict() elif hasattr(langchain_asset, "llm_chain"): llm_parameters = langchain_asset.llm_chain.llm.dict() elif hasattr(langchain_asset, "llm"): llm_parameters = langchain_asset.llm.dict() else: llm_parameters = langchain_asset.dict() except Exception: return {} return llm_parameters