Fix ClearML callback (#11472)

Handle different field names in dicts/dataframes, fixing the ClearML
callback.

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
pull/11690/head
eajechiloae 9 months ago committed by GitHub
parent 7ae8b7f065
commit 4ba2c8ba75
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,7 +1,9 @@
from __future__ import annotations
import tempfile
from copy import deepcopy
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Sequence
from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.utils import (
@ -15,6 +17,9 @@ from langchain.callbacks.utils import (
)
from langchain.schema import AgentAction, AgentFinish, LLMResult
if TYPE_CHECKING:
import pandas as pd
def import_clearml() -> Any:
"""Import the clearml python package and raise an error if it is not installed."""
@ -173,7 +178,7 @@ class ClearMLCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
resp.update(flatten_dict(serialized))
resp.update(self.get_custom_callback_meta())
chain_input = inputs["input"]
chain_input = inputs.get("input", inputs.get("human_input"))
if isinstance(chain_input, str):
input_resp = deepcopy(resp)
@ -200,7 +205,12 @@ class ClearMLCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
self.ends += 1
resp = self._init_resp()
resp.update({"action": "on_chain_end", "outputs": outputs["output"]})
resp.update(
{
"action": "on_chain_end",
"outputs": outputs.get("output", outputs.get("text")),
}
)
resp.update(self.get_custom_callback_meta())
self.on_chain_end_records.append(resp)
@ -372,16 +382,31 @@ class ClearMLCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
return resp
@staticmethod
def _build_llm_df(
base_df: pd.DataFrame, base_df_fields: Sequence, rename_map: Mapping
) -> pd.DataFrame:
base_df_fields = [field for field in base_df_fields if field in base_df]
rename_map = {
map_entry_k: map_entry_v
for map_entry_k, map_entry_v in rename_map.items()
if map_entry_k in base_df_fields
}
llm_df = base_df[base_df_fields].dropna(axis=1)
if rename_map:
llm_df = llm_df.rename(rename_map, axis=1)
return llm_df
def _create_session_analysis_df(self) -> Any:
"""Create a dataframe with all the information from the session."""
pd = import_pandas()
on_llm_start_records_df = pd.DataFrame(self.on_llm_start_records)
on_llm_end_records_df = pd.DataFrame(self.on_llm_end_records)
llm_input_prompts_df = (
on_llm_start_records_df[["step", "prompts", "name"]]
.dropna(axis=1)
.rename({"step": "prompt_step"}, axis=1)
llm_input_prompts_df = ClearMLCallbackHandler._build_llm_df(
base_df=on_llm_end_records_df,
base_df_fields=["step", "prompts"]
+ (["name"] if "name" in on_llm_end_records_df else ["id"]),
rename_map={"step": "prompt_step"},
)
complexity_metrics_columns = []
visualizations_columns: List = []
@ -406,30 +431,20 @@ class ClearMLCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
"osman",
]
llm_outputs_df = (
on_llm_end_records_df[
[
"step",
"text",
"token_usage_total_tokens",
"token_usage_prompt_tokens",
"token_usage_completion_tokens",
]
+ complexity_metrics_columns
+ visualizations_columns
llm_outputs_df = ClearMLCallbackHandler._build_llm_df(
on_llm_end_records_df,
[
"step",
"text",
"token_usage_total_tokens",
"token_usage_prompt_tokens",
"token_usage_completion_tokens",
]
.dropna(axis=1)
.rename({"step": "output_step", "text": "output"}, axis=1)
+ complexity_metrics_columns
+ visualizations_columns,
{"step": "output_step", "text": "output"},
)
session_analysis_df = pd.concat([llm_input_prompts_df, llm_outputs_df], axis=1)
# session_analysis_df["chat_html"] = session_analysis_df[
# ["prompts", "output"]
# ].apply(
# lambda row: construct_html_from_prompt_and_generation(
# row["prompts"], row["output"]
# ),
# axis=1,
# )
return session_analysis_df
def flush_tracker(

Loading…
Cancel
Save