From 4ba2c8ba75266d073f96a994e856316d369566b7 Mon Sep 17 00:00:00 2001
From: eajechiloae <97950284+eugen-ajechiloae-clearml@users.noreply.github.com>
Date: Thu, 12 Oct 2023 03:09:02 +0300
Subject: [PATCH] Fix ClearML callback (#11472)

Handle different field names in dicts/dataframes, fixing the ClearML callback.

---------

Co-authored-by: Bagatur
---
 .../langchain/callbacks/clearml_callback.py | 73 +++++++++++--------
 1 file changed, 44 insertions(+), 29 deletions(-)

diff --git a/libs/langchain/langchain/callbacks/clearml_callback.py b/libs/langchain/langchain/callbacks/clearml_callback.py
index 4d6610bf81..daa495693b 100644
--- a/libs/langchain/langchain/callbacks/clearml_callback.py
+++ b/libs/langchain/langchain/callbacks/clearml_callback.py
@@ -1,7 +1,9 @@
+from __future__ import annotations
+
 import tempfile
 from copy import deepcopy
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Sequence
+from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Sequence
 
 from langchain.callbacks.base import BaseCallbackHandler
 from langchain.callbacks.utils import (
@@ -15,6 +17,9 @@ from langchain.callbacks.utils import (
 )
 from langchain.schema import AgentAction, AgentFinish, LLMResult
 
+if TYPE_CHECKING:
+    import pandas as pd
+
 
 def import_clearml() -> Any:
     """Import the clearml python package and raise an error if it is not installed."""
@@ -173,7 +178,7 @@ class ClearMLCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
         resp.update(flatten_dict(serialized))
         resp.update(self.get_custom_callback_meta())
 
-        chain_input = inputs["input"]
+        chain_input = inputs.get("input", inputs.get("human_input"))
 
         if isinstance(chain_input, str):
             input_resp = deepcopy(resp)
@@ -200,7 +205,12 @@ class ClearMLCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
         self.ends += 1
 
         resp = self._init_resp()
-        resp.update({"action": "on_chain_end", "outputs": outputs["output"]})
+        resp.update(
+            {
+                "action": "on_chain_end",
+                "outputs": outputs.get("output", outputs.get("text")),
+            }
+        )
         resp.update(self.get_custom_callback_meta())
 
         self.on_chain_end_records.append(resp)
@@ -372,16 +382,31 @@ class ClearMLCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
 
         return resp
 
+    @staticmethod
+    def _build_llm_df(
+        base_df: pd.DataFrame, base_df_fields: Sequence, rename_map: Mapping
+    ) -> pd.DataFrame:
+        base_df_fields = [field for field in base_df_fields if field in base_df]
+        rename_map = {
+            map_entry_k: map_entry_v
+            for map_entry_k, map_entry_v in rename_map.items()
+            if map_entry_k in base_df_fields
+        }
+        llm_df = base_df[base_df_fields].dropna(axis=1)
+        if rename_map:
+            llm_df = llm_df.rename(rename_map, axis=1)
+        return llm_df
+
     def _create_session_analysis_df(self) -> Any:
         """Create a dataframe with all the information from the session."""
         pd = import_pandas()
-        on_llm_start_records_df = pd.DataFrame(self.on_llm_start_records)
         on_llm_end_records_df = pd.DataFrame(self.on_llm_end_records)
 
-        llm_input_prompts_df = (
-            on_llm_start_records_df[["step", "prompts", "name"]]
-            .dropna(axis=1)
-            .rename({"step": "prompt_step"}, axis=1)
+        llm_input_prompts_df = ClearMLCallbackHandler._build_llm_df(
+            base_df=on_llm_end_records_df,
+            base_df_fields=["step", "prompts"]
+            + (["name"] if "name" in on_llm_end_records_df else ["id"]),
+            rename_map={"step": "prompt_step"},
         )
         complexity_metrics_columns = []
         visualizations_columns: List = []
@@ -406,30 +431,20 @@ class ClearMLCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
                 "osman",
             ]
 
-        llm_outputs_df = (
-            on_llm_end_records_df[
-                [
-                    "step",
-                    "text",
-                    "token_usage_total_tokens",
-                    "token_usage_prompt_tokens",
-                    "token_usage_completion_tokens",
-                ]
-                + complexity_metrics_columns
-                + visualizations_columns
+        llm_outputs_df = ClearMLCallbackHandler._build_llm_df(
+            on_llm_end_records_df,
+            [
+                "step",
+                "text",
+                "token_usage_total_tokens",
+                "token_usage_prompt_tokens",
+                "token_usage_completion_tokens",
             ]
-            .dropna(axis=1)
-            .rename({"step": "output_step", "text": "output"}, axis=1)
+            + complexity_metrics_columns
+            + visualizations_columns,
+            {"step": "output_step", "text": "output"},
         )
         session_analysis_df = pd.concat([llm_input_prompts_df, llm_outputs_df], axis=1)
-        # session_analysis_df["chat_html"] = session_analysis_df[
-        #     ["prompts", "output"]
-        # ].apply(
-        #     lambda row: construct_html_from_prompt_and_generation(
-        #         row["prompts"], row["output"]
-        #     ),
-        #     axis=1,
-        # )
         return session_analysis_df
 
     def flush_tracker(
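
Note on the approach: both parts of the fix use the same lookup-with-fallback pattern. Chain inputs and outputs are read with dict.get() plus a fallback key (e.g. inputs.get("input", inputs.get("human_input"))), and the new _build_llm_df helper selects and renames only the dataframe columns that actually exist in the recorded rows. Below is a minimal standalone sketch of that pattern; the build_llm_df function name and the sample records are illustrative only, not part of the library.

# Illustrative sketch of the fallback pattern used by this patch (not library code).
import pandas as pd


def build_llm_df(base_df, base_df_fields, rename_map):
    # Keep only the requested columns that are actually present in this run's records.
    base_df_fields = [field for field in base_df_fields if field in base_df]
    # Rename only the columns that survived the filter above.
    rename_map = {old: new for old, new in rename_map.items() if old in base_df_fields}
    llm_df = base_df[base_df_fields].dropna(axis=1)
    return llm_df.rename(rename_map, axis=1) if rename_map else llm_df


# Hypothetical chain outputs: some chains return "output", others "text".
outputs = {"text": "final answer"}
print(outputs.get("output", outputs.get("text")))  # -> "final answer"

# Hypothetical LLM records: newer records may carry "id" instead of "name".
records_df = pd.DataFrame([{"step": 1, "prompts": "Hello", "id": "my-llm"}])
fields = ["step", "prompts"] + (["name"] if "name" in records_df else ["id"])
print(build_llm_df(records_df, fields, {"step": "prompt_step"}))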