diff --git a/langchain/callbacks/mlflow_callback.py b/langchain/callbacks/mlflow_callback.py index 8bae7739f4..89f5c6315f 100644 --- a/langchain/callbacks/mlflow_callback.py +++ b/langchain/callbacks/mlflow_callback.py @@ -551,8 +551,18 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler): on_llm_start_records_df = pd.DataFrame(self.records["on_llm_start_records"]) on_llm_end_records_df = pd.DataFrame(self.records["on_llm_end_records"]) + llm_input_columns = ["step", "prompt"] + if "name" in on_llm_start_records_df.columns: + llm_input_columns.append("name") + elif "id" in on_llm_start_records_df.columns: + # id is llm class's full import path. For example: + # ["langchain", "llms", "openai", "AzureOpenAI"] + on_llm_start_records_df["name"] = on_llm_start_records_df["id"].apply( + lambda id_: id_[-1] + ) + llm_input_columns.append("name") llm_input_prompts_df = ( - on_llm_start_records_df[["step", "prompt", "name"]] + on_llm_start_records_df[llm_input_columns] .dropna(axis=1) .rename({"step": "prompt_step"}, axis=1) )