@ -1,5 +1,6 @@
import time
import time
from typing import Any , Dict , List
from typing import Any , Dict , List , Optional
from uuid import UUID
from langchain_core . callbacks import BaseCallbackHandler
from langchain_core . callbacks import BaseCallbackHandler
from langchain_core . outputs import LLMResult
from langchain_core . outputs import LLMResult
@ -15,6 +16,11 @@ PROMPT_TOKENS = "prompt_tokens"
COMPLETION_TOKENS = " completion_tokens "
COMPLETION_TOKENS = " completion_tokens "
RUN_ID = " run_id "
RUN_ID = " run_id "
MODEL_NAME = " model_name "
MODEL_NAME = " model_name "
GOOD = " good "
BAD = " bad "
NEUTRAL = " neutral "
SUCCESS = " success "
FAILURE = " failure "
# Default values
# Default values
DEFAULT_MAX_TOKEN = 65536
DEFAULT_MAX_TOKEN = 65536
@ -23,12 +29,20 @@ DEFAULT_MAX_DURATION = 120
# Fiddler specific constants
# Fiddler specific constants
PROMPT = " prompt "
PROMPT = " prompt "
RESPONSE = " response "
RESPONSE = " response "
CONTEXT = " context "
DURATION = " duration "
DURATION = " duration "
FEEDBACK = " feedback "
LLM_STATUS = " llm_status "
FEEDBACK_POSSIBLE_VALUES = [ GOOD , BAD , NEUTRAL ]
# Define a dataset dictionary
# Define a dataset dictionary
_dataset_dict = {
_dataset_dict = {
PROMPT : [ " fiddler " ] * 10 ,
PROMPT : [ " fiddler " ] * 10 ,
RESPONSE : [ " fiddler " ] * 10 ,
RESPONSE : [ " fiddler " ] * 10 ,
CONTEXT : [ " fiddler " ] * 10 ,
FEEDBACK : [ " good " ] * 10 ,
LLM_STATUS : [ " success " ] * 10 ,
MODEL_NAME : [ " fiddler " ] * 10 ,
MODEL_NAME : [ " fiddler " ] * 10 ,
RUN_ID : [ " 123e4567-e89b-12d3-a456-426614174000 " ] * 10 ,
RUN_ID : [ " 123e4567-e89b-12d3-a456-426614174000 " ] * 10 ,
TOTAL_TOKENS : [ 0 , DEFAULT_MAX_TOKEN ] * 5 ,
TOTAL_TOKENS : [ 0 , DEFAULT_MAX_TOKEN ] * 5 ,
@ -83,8 +97,9 @@ class FiddlerCallbackHandler(BaseCallbackHandler):
self . api_key = api_key
self . api_key = api_key
self . _df = self . pd . DataFrame ( _dataset_dict )
self . _df = self . pd . DataFrame ( _dataset_dict )
self . run_id_prompts : Dict [ str , List [ str ] ] = { }
self . run_id_prompts : Dict [ UUID , List [ str ] ] = { }
self . run_id_starttime : Dict [ str , int ] = { }
self . run_id_response : Dict [ UUID , List [ str ] ] = { }
self . run_id_starttime : Dict [ UUID , int ] = { }
# Initialize Fiddler client here
# Initialize Fiddler client here
self . fiddler_client = self . fdl . FiddlerApi ( url , org_id = org , auth_token = api_key )
self . fiddler_client = self . fdl . FiddlerApi ( url , org_id = org , auth_token = api_key )
@ -105,6 +120,17 @@ class FiddlerCallbackHandler(BaseCallbackHandler):
dataset_info = self . fdl . DatasetInfo . from_dataframe (
dataset_info = self . fdl . DatasetInfo . from_dataframe (
self . _df , max_inferred_cardinality = 0
self . _df , max_inferred_cardinality = 0
)
)
# Set feedback column to categorical
for i in range ( len ( dataset_info . columns ) ) :
if dataset_info . columns [ i ] . name == FEEDBACK :
dataset_info . columns [ i ] . data_type = self . fdl . DataType . CATEGORY
dataset_info . columns [ i ] . possible_values = FEEDBACK_POSSIBLE_VALUES
elif dataset_info . columns [ i ] . name == LLM_STATUS :
dataset_info . columns [ i ] . data_type = self . fdl . DataType . CATEGORY
dataset_info . columns [ i ] . possible_values = [ SUCCESS , FAILURE ]
if self . model not in self . fiddler_client . get_dataset_names ( self . project ) :
if self . model not in self . fiddler_client . get_dataset_names ( self . project ) :
print ( # noqa: T201
print ( # noqa: T201
f " adding dataset { self . model } to project { self . project } . "
f " adding dataset { self . model } to project { self . project } . "
@ -128,13 +154,15 @@ class FiddlerCallbackHandler(BaseCallbackHandler):
dataset_info = dataset_info ,
dataset_info = dataset_info ,
dataset_id = " train " ,
dataset_id = " train " ,
model_task = self . fdl . ModelTask . LLM ,
model_task = self . fdl . ModelTask . LLM ,
features = [ PROMPT , RESPONSE ] ,
features = [ PROMPT , RESPONSE , CONTEXT ] ,
target = FEEDBACK ,
metadata_cols = [
metadata_cols = [
RUN_ID ,
RUN_ID ,
TOTAL_TOKENS ,
TOTAL_TOKENS ,
PROMPT_TOKENS ,
PROMPT_TOKENS ,
COMPLETION_TOKENS ,
COMPLETION_TOKENS ,
MODEL_NAME ,
MODEL_NAME ,
DURATION ,
] ,
] ,
custom_features = self . custom_features ,
custom_features = self . custom_features ,
)
)
@ -228,6 +256,42 @@ class FiddlerCallbackHandler(BaseCallbackHandler):
) ,
) ,
]
]
def _publish_events (
self ,
run_id : UUID ,
prompt_responses : List [ str ] ,
duration : int ,
llm_status : str ,
model_name : Optional [ str ] = " " ,
token_usage_dict : Optional [ Dict [ str , Any ] ] = None ,
) - > None :
"""
Publish events to fiddler
"""
prompt_count = len ( self . run_id_prompts [ run_id ] )
df = self . pd . DataFrame (
{
PROMPT : self . run_id_prompts [ run_id ] ,
RESPONSE : prompt_responses ,
RUN_ID : [ str ( run_id ) ] * prompt_count ,
DURATION : [ duration ] * prompt_count ,
LLM_STATUS : [ llm_status ] * prompt_count ,
MODEL_NAME : [ model_name ] * prompt_count ,
}
)
if token_usage_dict :
for key , value in token_usage_dict . items ( ) :
df [ key ] = [ value ] * prompt_count if isinstance ( value , int ) else value
try :
self . fiddler_client . publish_events_batch ( self . project , self . model , df )
except Exception as e :
print ( # noqa: T201
f " Error publishing events to fiddler: { e } . continuing... "
)
def on_llm_start (
def on_llm_start (
self , serialized : Dict [ str , Any ] , prompts : List [ str ] , * * kwargs : Any
self , serialized : Dict [ str , Any ] , prompts : List [ str ] , * * kwargs : Any
) - > Any :
) - > Any :
@ -237,42 +301,36 @@ class FiddlerCallbackHandler(BaseCallbackHandler):
def on_llm_end ( self , response : LLMResult , * * kwargs : Any ) - > None :
def on_llm_end ( self , response : LLMResult , * * kwargs : Any ) - > None :
flattened_llmresult = response . flatten ( )
flattened_llmresult = response . flatten ( )
token_usage_dict = { }
run_id = kwargs [ RUN_ID ]
run_id = kwargs [ RUN_ID ]
run_duration = self . run_id_starttime [ run_id ] - int ( time . time ( ) )
run_duration = self . run_id_starttime [ run_id ] - int ( time . time ( ) )
prompt_responses = [ ]
model_name = " "
model_name = " "
token_usage_dict = { }
if isinstance ( response . llm_output , dict ) :
if isinstance ( response . llm_output , dict ) :
if TOKEN_USAGE in response . llm_output :
token_usage_dict = {
token_usage_dict = response . llm_output [ TOKEN_USAGE ]
k : v
if MODEL_NAME in response . llm_output :
for k , v in response . llm_output . items ( )
model_name = response . llm_output [ MODEL_NAME ]
if k in [ TOTAL_TOKENS , PROMPT_TOKENS , COMPLETION_TOKENS ]
for llmresult in flattened_llmresult :
prompt_responses . append ( llmresult . generations [ 0 ] [ 0 ] . text )
df = self . pd . DataFrame (
{
PROMPT : self . run_id_prompts [ run_id ] ,
RESPONSE : prompt_responses ,
}
}
)
model_name = response . llm_output . get ( MODEL_NAME , " " )
if TOTAL_TOKENS in token_usage_dict :
df [ PROMPT_TOKENS ] = int ( token_usage_dict [ TOTAL_TOKENS ] )
if PROMPT_TOKENS in token_usage_dict :
prompt_responses = [
df [ TOTAL_TOKENS ] = int ( token_usage_dict [ PROMPT_TOKENS ] )
llmresult . generations [ 0 ] [ 0 ] . text for llmresult in flattened_llmresult
]
if COMPLETION_TOKENS in token_usage_dict :
self . _publish_events (
df [ COMPLETION_TOKENS ] = token_usage_dict [ COMPLETION_TOKENS ]
run_id ,
prompt_responses ,
run_duration ,
SUCCESS ,
model_name ,
token_usage_dict ,
)
df [ MODEL_NAME ] = model_name
def on_llm_error ( self , error : BaseException , * * kwargs : Any ) - > None :
df [ RUN_ID ] = str ( run_id )
run_id = kwargs [ RUN_ID ]
df [ DURATION ] = run_duration
d uration = int ( time . time ( ) ) - self . run_id_starttime [ run_id ]
try :
self . _publish_events (
self . fiddler_client . publish_events_batch ( self . project , self . model , df )
run_id , [ " " ] * len ( self . run_id_prompts [ run_id ] ) , duration , FAILURE
except Exception as e :
)
print ( f " Error publishing events to fiddler: { e } . continuing... " ) # noqa: T201