Fix ClearML callback (#11472)

Handle different field names in dicts/dataframes, fixing the ClearML
callback.

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
pull/11690/head
Authored by eajechiloae 11 months ago; committed via GitHub
parent 7ae8b7f065
commit 4ba2c8ba75
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,7 +1,9 @@
from __future__ import annotations
import tempfile import tempfile
from copy import deepcopy from copy import deepcopy
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Sequence
from langchain.callbacks.base import BaseCallbackHandler from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.utils import ( from langchain.callbacks.utils import (
@ -15,6 +17,9 @@ from langchain.callbacks.utils import (
) )
from langchain.schema import AgentAction, AgentFinish, LLMResult from langchain.schema import AgentAction, AgentFinish, LLMResult
if TYPE_CHECKING:
import pandas as pd
def import_clearml() -> Any: def import_clearml() -> Any:
"""Import the clearml python package and raise an error if it is not installed.""" """Import the clearml python package and raise an error if it is not installed."""
@ -173,7 +178,7 @@ class ClearMLCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
resp.update(flatten_dict(serialized)) resp.update(flatten_dict(serialized))
resp.update(self.get_custom_callback_meta()) resp.update(self.get_custom_callback_meta())
chain_input = inputs["input"] chain_input = inputs.get("input", inputs.get("human_input"))
if isinstance(chain_input, str): if isinstance(chain_input, str):
input_resp = deepcopy(resp) input_resp = deepcopy(resp)
@ -200,7 +205,12 @@ class ClearMLCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
self.ends += 1 self.ends += 1
resp = self._init_resp() resp = self._init_resp()
resp.update({"action": "on_chain_end", "outputs": outputs["output"]}) resp.update(
{
"action": "on_chain_end",
"outputs": outputs.get("output", outputs.get("text")),
}
)
resp.update(self.get_custom_callback_meta()) resp.update(self.get_custom_callback_meta())
self.on_chain_end_records.append(resp) self.on_chain_end_records.append(resp)
@ -372,16 +382,31 @@ class ClearMLCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
return resp return resp
@staticmethod
def _build_llm_df(
base_df: pd.DataFrame, base_df_fields: Sequence, rename_map: Mapping
) -> pd.DataFrame:
base_df_fields = [field for field in base_df_fields if field in base_df]
rename_map = {
map_entry_k: map_entry_v
for map_entry_k, map_entry_v in rename_map.items()
if map_entry_k in base_df_fields
}
llm_df = base_df[base_df_fields].dropna(axis=1)
if rename_map:
llm_df = llm_df.rename(rename_map, axis=1)
return llm_df
def _create_session_analysis_df(self) -> Any: def _create_session_analysis_df(self) -> Any:
"""Create a dataframe with all the information from the session.""" """Create a dataframe with all the information from the session."""
pd = import_pandas() pd = import_pandas()
on_llm_start_records_df = pd.DataFrame(self.on_llm_start_records)
on_llm_end_records_df = pd.DataFrame(self.on_llm_end_records) on_llm_end_records_df = pd.DataFrame(self.on_llm_end_records)
llm_input_prompts_df = ( llm_input_prompts_df = ClearMLCallbackHandler._build_llm_df(
on_llm_start_records_df[["step", "prompts", "name"]] base_df=on_llm_end_records_df,
.dropna(axis=1) base_df_fields=["step", "prompts"]
.rename({"step": "prompt_step"}, axis=1) + (["name"] if "name" in on_llm_end_records_df else ["id"]),
rename_map={"step": "prompt_step"},
) )
complexity_metrics_columns = [] complexity_metrics_columns = []
visualizations_columns: List = [] visualizations_columns: List = []
@ -406,30 +431,20 @@ class ClearMLCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
"osman", "osman",
] ]
llm_outputs_df = ( llm_outputs_df = ClearMLCallbackHandler._build_llm_df(
on_llm_end_records_df[ on_llm_end_records_df,
[ [
"step", "step",
"text", "text",
"token_usage_total_tokens", "token_usage_total_tokens",
"token_usage_prompt_tokens", "token_usage_prompt_tokens",
"token_usage_completion_tokens", "token_usage_completion_tokens",
]
+ complexity_metrics_columns
+ visualizations_columns
] ]
.dropna(axis=1) + complexity_metrics_columns
.rename({"step": "output_step", "text": "output"}, axis=1) + visualizations_columns,
{"step": "output_step", "text": "output"},
) )
session_analysis_df = pd.concat([llm_input_prompts_df, llm_outputs_df], axis=1) session_analysis_df = pd.concat([llm_input_prompts_df, llm_outputs_df], axis=1)
# session_analysis_df["chat_html"] = session_analysis_df[
# ["prompts", "output"]
# ].apply(
# lambda row: construct_html_from_prompt_and_generation(
# row["prompts"], row["output"]
# ),
# axis=1,
# )
return session_analysis_df return session_analysis_df
def flush_tracker( def flush_tracker(

Loading…
Cancel
Save