diff --git a/libs/community/langchain_community/callbacks/fiddler_callback.py b/libs/community/langchain_community/callbacks/fiddler_callback.py index 6f67cf3f76..dc1f7bbf55 100644 --- a/libs/community/langchain_community/callbacks/fiddler_callback.py +++ b/libs/community/langchain_community/callbacks/fiddler_callback.py @@ -1,5 +1,6 @@ import time -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional +from uuid import UUID from langchain_core.callbacks import BaseCallbackHandler from langchain_core.outputs import LLMResult @@ -15,6 +16,11 @@ PROMPT_TOKENS = "prompt_tokens" COMPLETION_TOKENS = "completion_tokens" RUN_ID = "run_id" MODEL_NAME = "model_name" +GOOD = "good" +BAD = "bad" +NEUTRAL = "neutral" +SUCCESS = "success" +FAILURE = "failure" # Default values DEFAULT_MAX_TOKEN = 65536 @@ -23,12 +29,20 @@ DEFAULT_MAX_DURATION = 120 # Fiddler specific constants PROMPT = "prompt" RESPONSE = "response" +CONTEXT = "context" DURATION = "duration" +FEEDBACK = "feedback" +LLM_STATUS = "llm_status" + +FEEDBACK_POSSIBLE_VALUES = [GOOD, BAD, NEUTRAL] # Define a dataset dictionary _dataset_dict = { PROMPT: ["fiddler"] * 10, RESPONSE: ["fiddler"] * 10, + CONTEXT: ["fiddler"] * 10, + FEEDBACK: ["good"] * 10, + LLM_STATUS: ["success"] * 10, MODEL_NAME: ["fiddler"] * 10, RUN_ID: ["123e4567-e89b-12d3-a456-426614174000"] * 10, TOTAL_TOKENS: [0, DEFAULT_MAX_TOKEN] * 5, @@ -83,8 +97,9 @@ class FiddlerCallbackHandler(BaseCallbackHandler): self.api_key = api_key self._df = self.pd.DataFrame(_dataset_dict) - self.run_id_prompts: Dict[str, List[str]] = {} - self.run_id_starttime: Dict[str, int] = {} + self.run_id_prompts: Dict[UUID, List[str]] = {} + self.run_id_response: Dict[UUID, List[str]] = {} + self.run_id_starttime: Dict[UUID, int] = {} # Initialize Fiddler client here self.fiddler_client = self.fdl.FiddlerApi(url, org_id=org, auth_token=api_key) @@ -105,6 +120,17 @@ class FiddlerCallbackHandler(BaseCallbackHandler): dataset_info = self.fdl.DatasetInfo.from_dataframe( self._df, max_inferred_cardinality=0 ) + + # Set feedback column to categorical + for i in range(len(dataset_info.columns)): + if dataset_info.columns[i].name == FEEDBACK: + dataset_info.columns[i].data_type = self.fdl.DataType.CATEGORY + dataset_info.columns[i].possible_values = FEEDBACK_POSSIBLE_VALUES + + elif dataset_info.columns[i].name == LLM_STATUS: + dataset_info.columns[i].data_type = self.fdl.DataType.CATEGORY + dataset_info.columns[i].possible_values = [SUCCESS, FAILURE] + if self.model not in self.fiddler_client.get_dataset_names(self.project): print( # noqa: T201 f"adding dataset {self.model} to project {self.project}." @@ -128,13 +154,15 @@ class FiddlerCallbackHandler(BaseCallbackHandler): dataset_info=dataset_info, dataset_id="train", model_task=self.fdl.ModelTask.LLM, - features=[PROMPT, RESPONSE], + features=[PROMPT, RESPONSE, CONTEXT], + target=FEEDBACK, metadata_cols=[ RUN_ID, TOTAL_TOKENS, PROMPT_TOKENS, COMPLETION_TOKENS, MODEL_NAME, + DURATION, ], custom_features=self.custom_features, ) @@ -228,6 +256,42 @@ class FiddlerCallbackHandler(BaseCallbackHandler): ), ] + def _publish_events( + self, + run_id: UUID, + prompt_responses: List[str], + duration: int, + llm_status: str, + model_name: Optional[str] = "", + token_usage_dict: Optional[Dict[str, Any]] = None, + ) -> None: + """ + Publish events to fiddler + """ + + prompt_count = len(self.run_id_prompts[run_id]) + df = self.pd.DataFrame( + { + PROMPT: self.run_id_prompts[run_id], + RESPONSE: prompt_responses, + RUN_ID: [str(run_id)] * prompt_count, + DURATION: [duration] * prompt_count, + LLM_STATUS: [llm_status] * prompt_count, + MODEL_NAME: [model_name] * prompt_count, + } + ) + + if token_usage_dict: + for key, value in token_usage_dict.items(): + df[key] = [value] * prompt_count if isinstance(value, int) else value + + try: + self.fiddler_client.publish_events_batch(self.project, self.model, df) + except Exception as e: + print( # noqa: T201 + f"Error publishing events to fiddler: {e}. continuing..." + ) + def on_llm_start( self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any ) -> Any: @@ -237,42 +301,36 @@ class FiddlerCallbackHandler(BaseCallbackHandler): def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: flattened_llmresult = response.flatten() - token_usage_dict = {} run_id = kwargs[RUN_ID] run_duration = self.run_id_starttime[run_id] - int(time.time()) - prompt_responses = [] model_name = "" + token_usage_dict = {} if isinstance(response.llm_output, dict): - if TOKEN_USAGE in response.llm_output: - token_usage_dict = response.llm_output[TOKEN_USAGE] - if MODEL_NAME in response.llm_output: - model_name = response.llm_output[MODEL_NAME] - - for llmresult in flattened_llmresult: - prompt_responses.append(llmresult.generations[0][0].text) - - df = self.pd.DataFrame( - { - PROMPT: self.run_id_prompts[run_id], - RESPONSE: prompt_responses, + token_usage_dict = { + k: v + for k, v in response.llm_output.items() + if k in [TOTAL_TOKENS, PROMPT_TOKENS, COMPLETION_TOKENS] } - ) - - if TOTAL_TOKENS in token_usage_dict: - df[PROMPT_TOKENS] = int(token_usage_dict[TOTAL_TOKENS]) + model_name = response.llm_output.get(MODEL_NAME, "") - if PROMPT_TOKENS in token_usage_dict: - df[TOTAL_TOKENS] = int(token_usage_dict[PROMPT_TOKENS]) + prompt_responses = [ + llmresult.generations[0][0].text for llmresult in flattened_llmresult + ] - if COMPLETION_TOKENS in token_usage_dict: - df[COMPLETION_TOKENS] = token_usage_dict[COMPLETION_TOKENS] + self._publish_events( + run_id, + prompt_responses, + run_duration, + SUCCESS, + model_name, + token_usage_dict, + ) - df[MODEL_NAME] = model_name - df[RUN_ID] = str(run_id) - df[DURATION] = run_duration + def on_llm_error(self, error: BaseException, **kwargs: Any) -> None: + run_id = kwargs[RUN_ID] + duration = int(time.time()) - self.run_id_starttime[run_id] - try: - self.fiddler_client.publish_events_batch(self.project, self.model, df) - except Exception as e: - print(f"Error publishing events to fiddler: {e}. continuing...") # noqa: T201 + self._publish_events( + run_id, [""] * len(self.run_id_prompts[run_id]), duration, FAILURE + )