@ -1,13 +1,14 @@
# Import the necessary packages for ingestion
import uuid
import uuid
from datetime import datetime
from typing import Any , Dict , List , Optional , Union
from typing import Any , Dict , List , Optional , Union
from langchain . callbacks . base import BaseCallbackHandler
from langchain . callbacks . base import BaseCallbackHandler
from langchain . callbacks . utils import import_pandas
from langchain . schema import AgentAction , AgentFinish , LLMResult
from langchain . schema import AgentAction , AgentFinish , LLMResult
class ArizeCallbackHandler ( BaseCallbackHandler ) :
class ArizeCallbackHandler ( BaseCallbackHandler ) :
""" Callback Handler that logs to Arize platform ."""
""" Callback Handler that logs to Arize ."""
def __init__ (
def __init__ (
self ,
self ,
@ -19,41 +20,29 @@ class ArizeCallbackHandler(BaseCallbackHandler):
""" Initialize callback handler. """
""" Initialize callback handler. """
super ( ) . __init__ ( )
super ( ) . __init__ ( )
# Set the model_id and model_version for the Arize monitoring.
self . model_id = model_id
self . model_id = model_id
self . model_version = model_version
self . model_version = model_version
# Set the SPACE_KEY and API_KEY for the Arize client.
self . space_key = SPACE_KEY
self . space_key = SPACE_KEY
self . api_key = API_KEY
self . api_key = API_KEY
self . prompt_records : List [ str ] = [ ]
# Initialize empty lists to store the prompt/response pairs
self . response_records : List [ str ] = [ ]
# and other necessary data.
self . prediction_ids : List [ str ] = [ ]
self . prompt_records : List = [ ]
self . pred_timestamps : List [ int ] = [ ]
self . response_records : List = [ ]
self . response_embeddings : List [ float ] = [ ]
self . prediction_ids : List = [ ]
self . prompt_embeddings : List [ float ] = [ ]
self . pred_timestamps : List = [ ]
self . response_embeddings : List = [ ]
self . prompt_embeddings : List = [ ]
self . prompt_tokens = 0
self . prompt_tokens = 0
self . completion_tokens = 0
self . completion_tokens = 0
self . total_tokens = 0
self . total_tokens = 0
from arize . api import Client
from arize . pandas . embeddings import EmbeddingGenerator , UseCases
from arize . pandas . embeddings import EmbeddingGenerator , UseCases
from arize . pandas . logger import Client
# Create an embedding generator for generating embeddings
# from prompts and responses.
self . generator = EmbeddingGenerator . from_use_case (
self . generator = EmbeddingGenerator . from_use_case (
use_case = UseCases . NLP . SEQUENCE_CLASSIFICATION ,
use_case = UseCases . NLP . SEQUENCE_CLASSIFICATION ,
model_name = " distilbert-base-uncased " ,
model_name = " distilbert-base-uncased " ,
tokenizer_max_length = 512 ,
tokenizer_max_length = 512 ,
batch_size = 256 ,
batch_size = 256 ,
)
)
# Create an Arize client and check if the SPACE_KEY and API_KEY
# are not set to the default values.
self . arize_client = Client ( space_key = SPACE_KEY , api_key = API_KEY )
self . arize_client = Client ( space_key = SPACE_KEY , api_key = API_KEY )
if SPACE_KEY == " SPACE_KEY " or API_KEY == " API_KEY " :
if SPACE_KEY == " SPACE_KEY " or API_KEY == " API_KEY " :
raise ValueError ( " ❌ CHANGE SPACE AND API KEYS " )
raise ValueError ( " ❌ CHANGE SPACE AND API KEYS " )
@ -63,32 +52,40 @@ class ArizeCallbackHandler(BaseCallbackHandler):
def on_llm_start (
def on_llm_start (
self , serialized : Dict [ str , Any ] , prompts : List [ str ] , * * kwargs : Any
self , serialized : Dict [ str , Any ] , prompts : List [ str ] , * * kwargs : Any
) - > None :
) - > None :
""" Record the prompts when an LLM starts. """
for prompt in prompts :
for prompt in prompts :
self . prompt_records . append ( prompt . replace ( " \n " , " ") )
self . prompt_records . append ( prompt . replace ( " \n " , " " ) )
def on_llm_new_token ( self , token : str , * * kwargs : Any ) - > None :
def on_llm_new_token ( self , token : str , * * kwargs : Any ) - > None :
""" Do nothing when a new token is generated ."""
""" Do nothing ."""
pass
pass
def on_llm_end ( self , response : LLMResult , * * kwargs : Any ) - > None :
def on_llm_end ( self , response : LLMResult , * * kwargs : Any ) - > None :
""" Log data to Arize when an LLM ends. """
pd = import_pandas ( )
from arize . utils . types import (
EmbeddingColumnNames ,
Environments ,
ModelTypes ,
Schema ,
)
import pandas as pd
# Safe check if 'llm_output' and 'token_usage' exist
from arize . utils . types import Embedding , Environments , ModelTypes
if response . llm_output and " token_usage " in response . llm_output :
self . prompt_tokens = response . llm_output [ " token_usage " ] . get (
" prompt_tokens " , 0
)
self . total_tokens = response . llm_output [ " token_usage " ] . get (
" total_tokens " , 0
)
self . completion_tokens = response . llm_output [ " token_usage " ] . get (
" completion_tokens " , 0
)
else :
self . prompt_tokens = (
self . total_tokens
) = self . completion_tokens = 0 # assign default value
# Record token usage of the LLM
if response . llm_output is not None :
self . prompt_tokens = response . llm_output [ " token_usage " ] [ " prompt_tokens " ]
self . total_tokens = response . llm_output [ " token_usage " ] [ " total_tokens " ]
self . completion_tokens = response . llm_output [ " token_usage " ] [
" completion_tokens "
]
i = 0
i = 0
# Go through each prompt response pair and generate embeddings as
# well as timestamp and prediction ids
for generations in response . generations :
for generations in response . generations :
for generation in generations :
for generation in generations :
prompt = self . prompt_records [ i ]
prompt = self . prompt_records [ i ]
@ -97,72 +94,98 @@ class ArizeCallbackHandler(BaseCallbackHandler):
text_col = pd . Series ( prompt . replace ( " \n " , " " ) )
text_col = pd . Series ( prompt . replace ( " \n " , " " ) )
) . reset_index ( drop = True )
) . reset_index ( drop = True )
)
)
generated_text = generation . text . replace ( " \n " , " " )
# Assigning text to response_text instead of response
response_text = generation . text . replace ( " \n " , " " )
response_embedding = pd . Series (
response_embedding = pd . Series (
self . generator . generate_embeddings (
self . generator . generate_embeddings (
text_col = pd . Series ( generation . text . replace ( " \n " , " " ) )
text_col = pd . Series ( generation . text . replace ( " \n " , " " ) )
) . reset_index ( drop = True )
) . reset_index ( drop = True )
)
)
pred_id = str ( uuid . uuid4 ( ) )
str ( uuid . uuid4 ( ) )
pred_timestamp = datetime . now ( ) . timestamp ( )
# Define embedding features for Arize ingestion
embedding_features = {
# Define the columns and data
" prompt_embedding " : Embedding (
columns = [
vector = pd . Series ( prompt_embedding [ 0 ] ) , data = prompt
" prediction_ts " ,
) ,
" response " ,
" response_embedding " : Embedding (
" prompt " ,
vector = pd . Series ( response_embedding [ 0 ] ) , data = generated_text
" response_vector " ,
) ,
" prompt_vector " ,
}
" prompt_token " ,
tags = {
" completion_token " ,
" Prompt Tokens " : self . prompt_tokens ,
" total_token " ,
" Completion Tokens " : self . completion_tokens ,
]
" Total Tokens " : self . total_tokens ,
data = [
}
[
pred_timestamp ,
# Log each prompt response data into arize
response_text ,
future = self . arize_client . log (
prompt ,
prediction_id = pred_id ,
response_embedding [ 0 ] ,
tags = tags ,
prompt_embedding [ 0 ] ,
prediction_label = " 1 " ,
self . prompt_tokens ,
self . total_tokens ,
self . completion_tokens ,
]
]
# Create the DataFrame
df = pd . DataFrame ( data , columns = columns )
# Declare prompt and response columns
prompt_columns = EmbeddingColumnNames (
vector_column_name = " prompt_vector " , data_column_name = " prompt "
)
response_columns = EmbeddingColumnNames (
vector_column_name = " response_vector " , data_column_name = " response "
)
schema = Schema (
timestamp_column_name = " prediction_ts " ,
tag_column_names = [
" prompt_token " ,
" completion_token " ,
" total_token " ,
] ,
prompt_column_names = prompt_columns ,
response_column_names = response_columns ,
)
response_from_arize = self . arize_client . log (
dataframe = df ,
schema = schema ,
model_id = self . model_id ,
model_id = self . model_id ,
model_type = ModelTypes . SCORE_CATEGORICAL ,
model_version = self . model_version ,
model_version = self . model_version ,
model_type = ModelTypes . GENERATIVE_LLM ,
environment = Environments . PRODUCTION ,
environment = Environments . PRODUCTION ,
embedding_features = embedding_features ,
)
)
if response_from_arize . status_code == 200 :
result = future . result ( )
if result . status_code == 200 :
print ( " ✅ Successfully logged data to Arize! " )
print ( " ✅ Successfully logged data to Arize! " )
else :
else :
print (
print ( f ' ❌ Logging failed " { response_from_arize . text } " ' )
f " ❌ Logging failed with status code { result . status_code } "
f ' and message " { result . text } " '
)
i = i + 1
i = i + 1
def on_llm_error (
def on_llm_error (
self , error : Union [ Exception , KeyboardInterrupt ] , * * kwargs : Any
self , error : Union [ Exception , KeyboardInterrupt ] , * * kwargs : Any
) - > None :
) - > None :
""" Do nothing when LLM outputs an error. """
""" Do nothing ."""
pass
pass
def on_chain_start (
def on_chain_start (
self , serialized : Dict [ str , Any ] , inputs : Dict [ str , Any ] , * * kwargs : Any
self , serialized : Dict [ str , Any ] , inputs : Dict [ str , Any ] , * * kwargs : Any
) - > None :
) - > None :
""" Do nothing when LLM chain starts. """
pass
pass
def on_chain_end ( self , outputs : Dict [ str , Any ] , * * kwargs : Any ) - > None :
def on_chain_end ( self , outputs : Dict [ str , Any ] , * * kwargs : Any ) - > None :
""" Do nothing when LLM chain ends ."""
""" Do nothing ."""
pass
pass
def on_chain_error (
def on_chain_error (
self , error : Union [ Exception , KeyboardInterrupt ] , * * kwargs : Any
self , error : Union [ Exception , KeyboardInterrupt ] , * * kwargs : Any
) - > None :
) - > None :
""" Do nothing when LLM chain outputs an error ."""
""" Do nothing ."""
pass
pass
def on_tool_start (
def on_tool_start (
@ -171,11 +194,10 @@ class ArizeCallbackHandler(BaseCallbackHandler):
input_str : str ,
input_str : str ,
* * kwargs : Any ,
* * kwargs : Any ,
) - > None :
) - > None :
""" Do nothing when tool starts. """
pass
pass
def on_agent_action ( self , action : AgentAction , * * kwargs : Any ) - > Any :
def on_agent_action ( self , action : AgentAction , * * kwargs : Any ) - > Any :
""" Do nothing when agent takes a specific action ."""
""" Do nothing ."""
pass
pass
def on_tool_end (
def on_tool_end (
@ -185,19 +207,15 @@ class ArizeCallbackHandler(BaseCallbackHandler):
llm_prefix : Optional [ str ] = None ,
llm_prefix : Optional [ str ] = None ,
* * kwargs : Any ,
* * kwargs : Any ,
) - > None :
) - > None :
""" Do nothing when tool ends. """
pass
pass
def on_tool_error (
def on_tool_error (
self , error : Union [ Exception , KeyboardInterrupt ] , * * kwargs : Any
self , error : Union [ Exception , KeyboardInterrupt ] , * * kwargs : Any
) - > None :
) - > None :
""" Do nothing when tool outputs an error. """
pass
pass
def on_text ( self , text : str , * * kwargs : Any ) - > None :
def on_text ( self , text : str , * * kwargs : Any ) - > None :
""" Do nothing """
pass
pass
def on_agent_finish ( self , finish : AgentFinish , * * kwargs : Any ) - > None :
def on_agent_finish ( self , finish : AgentFinish , * * kwargs : Any ) - > None :
""" Do nothing """
pass
pass