diff --git a/langchain/callbacks/arize_callback.py b/langchain/callbacks/arize_callback.py
index a93dddfe..7e1196e7 100644
--- a/langchain/callbacks/arize_callback.py
+++ b/langchain/callbacks/arize_callback.py
@@ -1,13 +1,14 @@
-# Import the necessary packages for ingestion
 import uuid
+from datetime import datetime
 from typing import Any, Dict, List, Optional, Union
 
 from langchain.callbacks.base import BaseCallbackHandler
+from langchain.callbacks.utils import import_pandas
 from langchain.schema import AgentAction, AgentFinish, LLMResult
 
 
 class ArizeCallbackHandler(BaseCallbackHandler):
-    """Callback Handler that logs to Arize platform."""
+    """Callback Handler that logs to Arize."""
 
     def __init__(
         self,
@@ -19,41 +20,29 @@ class ArizeCallbackHandler(BaseCallbackHandler):
         """Initialize callback handler."""
         super().__init__()
-
-        # Set the model_id and model_version for the Arize monitoring.
         self.model_id = model_id
         self.model_version = model_version
-
-        # Set the SPACE_KEY and API_KEY for the Arize client.
         self.space_key = SPACE_KEY
         self.api_key = API_KEY
-
-        # Initialize empty lists to store the prompt/response pairs
-        # and other necessary data.
-        self.prompt_records: List = []
-        self.response_records: List = []
-        self.prediction_ids: List = []
-        self.pred_timestamps: List = []
-        self.response_embeddings: List = []
-        self.prompt_embeddings: List = []
+        self.prompt_records: List[str] = []
+        self.response_records: List[str] = []
+        self.prediction_ids: List[str] = []
+        self.pred_timestamps: List[int] = []
+        self.response_embeddings: List[float] = []
+        self.prompt_embeddings: List[float] = []
         self.prompt_tokens = 0
         self.completion_tokens = 0
         self.total_tokens = 0
 
-        from arize.api import Client
         from arize.pandas.embeddings import EmbeddingGenerator, UseCases
+        from arize.pandas.logger import Client
 
-        # Create an embedding generator for generating embeddings
-        # from prompts and responses.
         self.generator = EmbeddingGenerator.from_use_case(
             use_case=UseCases.NLP.SEQUENCE_CLASSIFICATION,
             model_name="distilbert-base-uncased",
             tokenizer_max_length=512,
             batch_size=256,
         )
-
-        # Create an Arize client and check if the SPACE_KEY and API_KEY
-        # are not set to the default values.
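+        # The pandas logger Client logs a whole DataFrame against a Schema
+        # and returns the HTTP response directly (no future.result() needed,
+        # unlike the arize.api.Client it replaces).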
         self.arize_client = Client(space_key=SPACE_KEY, api_key=API_KEY)
         if SPACE_KEY == "SPACE_KEY" or API_KEY == "API_KEY":
             raise ValueError("❌ CHANGE SPACE AND API KEYS")
@@ -63,32 +52,40 @@ class ArizeCallbackHandler(BaseCallbackHandler):
     def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
     ) -> None:
-        """Record the prompts when an LLM starts."""
         for prompt in prompts:
-            self.prompt_records.append(prompt.replace("\n", " "))
+            self.prompt_records.append(prompt.replace("\n", ""))
 
     def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
-        """Do nothing when a new token is generated."""
+        """Do nothing."""
         pass
 
     def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
-        """Log data to Arize when an LLM ends."""
-
-        import pandas as pd
-        from arize.utils.types import Embedding, Environments, ModelTypes
-
-        # Record token usage of the LLM
-        if response.llm_output is not None:
-            self.prompt_tokens = response.llm_output["token_usage"]["prompt_tokens"]
-            self.total_tokens = response.llm_output["token_usage"]["total_tokens"]
-            self.completion_tokens = response.llm_output["token_usage"][
-                "completion_tokens"
-            ]
+        pd = import_pandas()
+        from arize.utils.types import (
+            EmbeddingColumnNames,
+            Environments,
+            ModelTypes,
+            Schema,
+        )
+
+        # Read token usage only if 'llm_output' and 'token_usage' exist
+        if response.llm_output and "token_usage" in response.llm_output:
+            self.prompt_tokens = response.llm_output["token_usage"].get(
+                "prompt_tokens", 0
+            )
+            self.total_tokens = response.llm_output["token_usage"].get(
+                "total_tokens", 0
+            )
+            self.completion_tokens = response.llm_output["token_usage"].get(
+                "completion_tokens", 0
+            )
+        else:
+            # Default all token counts to zero when usage is unavailable
+            self.prompt_tokens = self.completion_tokens = self.total_tokens = 0
+
         i = 0
-        # Go through each prompt response pair and generate embeddings as
-        # well as timestamp and prediction ids
         for generations in response.generations:
             for generation in generations:
                 prompt = self.prompt_records[i]
@@ -97,72 +94,98 @@ class ArizeCallbackHandler(BaseCallbackHandler):
                         text_col=pd.Series(prompt.replace("\n", " "))
                     ).reset_index(drop=True)
                 )
-                generated_text = generation.text.replace("\n", " ")
+
+                # Replace newlines with spaces so the response logs as one line
+                response_text = generation.text.replace("\n", " ")
                 response_embedding = pd.Series(
                     self.generator.generate_embeddings(
                         text_col=pd.Series(generation.text.replace("\n", " "))
                     ).reset_index(drop=True)
                 )
-                pred_id = str(uuid.uuid4())
-
-                # Define embedding features for Arize ingestion
-                embedding_features = {
-                    "prompt_embedding": Embedding(
-                        vector=pd.Series(prompt_embedding[0]), data=prompt
-                    ),
-                    "response_embedding": Embedding(
-                        vector=pd.Series(response_embedding[0]), data=generated_text
-                    ),
-                }
-                tags = {
-                    "Prompt Tokens": self.prompt_tokens,
-                    "Completion Tokens": self.completion_tokens,
-                    "Total Tokens": self.total_tokens,
-                }
-
-                # Log each prompt response data into arize
-                future = self.arize_client.log(
-                    prediction_id=pred_id,
-                    tags=tags,
-                    prediction_label="1",
+                pred_timestamp = datetime.now().timestamp()
+
+                # Define the columns and data
+                columns = [
+                    "prediction_ts",
+                    "response",
+                    "prompt",
+                    "response_vector",
+                    "prompt_vector",
+                    "prompt_token",
+                    "completion_token",
+                    "total_token",
+                ]
+                data = [
+                    [
+                        pred_timestamp,
+                        response_text,
+                        prompt,
+                        response_embedding[0],
+                        prompt_embedding[0],
+                        self.prompt_tokens,
+                        self.completion_tokens,
+                        self.total_tokens,
+                    ]
+                ]
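+                # NOTE: the row values above must stay in the same order as
+                # the names in `columns` (prompt, completion, total tokens).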
+
+                # Create the DataFrame
+                df = pd.DataFrame(data, columns=columns)
+
+                # Declare prompt and response columns
+                prompt_columns = EmbeddingColumnNames(
+                    vector_column_name="prompt_vector", data_column_name="prompt"
+                )
+
+                response_columns = EmbeddingColumnNames(
+                    vector_column_name="response_vector", data_column_name="response"
+                )
+
+                schema = Schema(
+                    timestamp_column_name="prediction_ts",
+                    tag_column_names=[
+                        "prompt_token",
+                        "completion_token",
+                        "total_token",
+                    ],
+                    prompt_column_names=prompt_columns,
+                    response_column_names=response_columns,
+                )
+
+                response_from_arize = self.arize_client.log(
+                    dataframe=df,
+                    schema=schema,
                     model_id=self.model_id,
-                    model_type=ModelTypes.SCORE_CATEGORICAL,
                     model_version=self.model_version,
+                    model_type=ModelTypes.GENERATIVE_LLM,
                     environment=Environments.PRODUCTION,
-                    embedding_features=embedding_features,
                 )
-
-                result = future.result()
-                if result.status_code == 200:
+                if response_from_arize.status_code == 200:
                     print("✅ Successfully logged data to Arize!")
                 else:
-                    print(
-                        f"❌ Logging failed with status code {result.status_code}"
-                        f' and message "{result.text}"'
-                    )
+                    print(f'❌ Logging failed "{response_from_arize.text}"')
                 i = i + 1
 
     def on_llm_error(
         self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
     ) -> None:
-        """Do nothing when LLM outputs an error."""
+        """Do nothing."""
         pass
 
     def on_chain_start(
         self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
     ) -> None:
-        """Do nothing when LLM chain starts."""
         pass
 
     def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
-        """Do nothing when LLM chain ends."""
+        """Do nothing."""
         pass
 
     def on_chain_error(
         self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
     ) -> None:
-        """Do nothing when LLM chain outputs an error."""
+        """Do nothing."""
         pass
 
     def on_tool_start(
@@ -171,11 +194,10 @@ class ArizeCallbackHandler(BaseCallbackHandler):
         input_str: str,
         **kwargs: Any,
     ) -> None:
-        """Do nothing when tool starts."""
         pass
 
     def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
-        """Do nothing when agent takes a specific action."""
+        """Do nothing."""
         pass
 
     def on_tool_end(
@@ -185,19 +207,15 @@ class ArizeCallbackHandler(BaseCallbackHandler):
         llm_prefix: Optional[str] = None,
         **kwargs: Any,
     ) -> None:
-        """Do nothing when tool ends."""
         pass
 
     def on_tool_error(
         self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
     ) -> None:
-        """Do nothing when tool outputs an error."""
         pass
 
     def on_text(self, text: str, **kwargs: Any) -> None:
-        """Do nothing"""
         pass
 
     def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
-        """Do nothing"""
         pass
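
Usage sketch (not part of the diff): a minimal example of wiring the handler into an LLM call, assuming a langchain version whose LLM wrappers accept a `callbacks` argument; the credentials and model names below are placeholders.

    from langchain.callbacks.arize_callback import ArizeCallbackHandler
    from langchain.llms import OpenAI

    # Placeholder credentials; the handler raises ValueError if the literal
    # defaults "SPACE_KEY"/"API_KEY" are passed in.
    arize_callback = ArizeCallbackHandler(
        model_id="llm-langchain-demo",
        model_version="1.0",
        SPACE_KEY="<your-arize-space-key>",
        API_KEY="<your-arize-api-key>",
    )

    # on_llm_start records each prompt; on_llm_end embeds the prompt/response
    # pair and logs it to Arize as a one-row DataFrame.
    llm = OpenAI(temperature=0, callbacks=[arize_callback])
    llm("Explain what a callback handler does.")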