From 6a157cf8bbc113ff2192b57045b42633c82484cf Mon Sep 17 00:00:00 2001 From: Hakan Tekgul Date: Mon, 19 Jun 2023 18:33:49 -0700 Subject: [PATCH] Update arize_callback.py (#6433) Arize released a new Generative LLM Model Type, adjusting the callback function to new logging. Added arize imports, please delete if not necessary. Specifically, this change makes sure that the prompt and response pairs from LangChain agents are logged into Arize as a Generative LLM model, instead of our previous categorical model. In order to do this, the callback functions collects the necessary data and passes the data into Arize using Python Pandas SDK. Arize library, specifically pandas.logger is an additional dependency. Notebook For Test: https://docs.arize.com/arize/resources/integrations/langchain Who can review? Tag maintainers/contributors who might be interested: @hwchase17 - project lead Tracing / Callbacks @agola11 --- langchain/callbacks/arize_callback.py | 180 ++++++++++++++------------ 1 file changed, 99 insertions(+), 81 deletions(-) diff --git a/langchain/callbacks/arize_callback.py b/langchain/callbacks/arize_callback.py index a93dddfe..7e1196e7 100644 --- a/langchain/callbacks/arize_callback.py +++ b/langchain/callbacks/arize_callback.py @@ -1,13 +1,14 @@ -# Import the necessary packages for ingestion import uuid +from datetime import datetime from typing import Any, Dict, List, Optional, Union from langchain.callbacks.base import BaseCallbackHandler +from langchain.callbacks.utils import import_pandas from langchain.schema import AgentAction, AgentFinish, LLMResult class ArizeCallbackHandler(BaseCallbackHandler): - """Callback Handler that logs to Arize platform.""" + """Callback Handler that logs to Arize.""" def __init__( self, @@ -19,41 +20,29 @@ class ArizeCallbackHandler(BaseCallbackHandler): """Initialize callback handler.""" super().__init__() - - # Set the model_id and model_version for the Arize monitoring. self.model_id = model_id self.model_version = model_version - - # Set the SPACE_KEY and API_KEY for the Arize client. self.space_key = SPACE_KEY self.api_key = API_KEY - - # Initialize empty lists to store the prompt/response pairs - # and other necessary data. - self.prompt_records: List = [] - self.response_records: List = [] - self.prediction_ids: List = [] - self.pred_timestamps: List = [] - self.response_embeddings: List = [] - self.prompt_embeddings: List = [] + self.prompt_records: List[str] = [] + self.response_records: List[str] = [] + self.prediction_ids: List[str] = [] + self.pred_timestamps: List[int] = [] + self.response_embeddings: List[float] = [] + self.prompt_embeddings: List[float] = [] self.prompt_tokens = 0 self.completion_tokens = 0 self.total_tokens = 0 - from arize.api import Client from arize.pandas.embeddings import EmbeddingGenerator, UseCases + from arize.pandas.logger import Client - # Create an embedding generator for generating embeddings - # from prompts and responses. self.generator = EmbeddingGenerator.from_use_case( use_case=UseCases.NLP.SEQUENCE_CLASSIFICATION, model_name="distilbert-base-uncased", tokenizer_max_length=512, batch_size=256, ) - - # Create an Arize client and check if the SPACE_KEY and API_KEY - # are not set to the default values. self.arize_client = Client(space_key=SPACE_KEY, api_key=API_KEY) if SPACE_KEY == "SPACE_KEY" or API_KEY == "API_KEY": raise ValueError("❌ CHANGE SPACE AND API KEYS") @@ -63,32 +52,40 @@ class ArizeCallbackHandler(BaseCallbackHandler): def on_llm_start( self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any ) -> None: - """Record the prompts when an LLM starts.""" - for prompt in prompts: - self.prompt_records.append(prompt.replace("\n", " ")) + self.prompt_records.append(prompt.replace("\n", "")) def on_llm_new_token(self, token: str, **kwargs: Any) -> None: - """Do nothing when a new token is generated.""" + """Do nothing.""" pass def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: - """Log data to Arize when an LLM ends.""" - - import pandas as pd - from arize.utils.types import Embedding, Environments, ModelTypes - - # Record token usage of the LLM - if response.llm_output is not None: - self.prompt_tokens = response.llm_output["token_usage"]["prompt_tokens"] - self.total_tokens = response.llm_output["token_usage"]["total_tokens"] - self.completion_tokens = response.llm_output["token_usage"][ - "completion_tokens" - ] + pd = import_pandas() + from arize.utils.types import ( + EmbeddingColumnNames, + Environments, + ModelTypes, + Schema, + ) + + # Safe check if 'llm_output' and 'token_usage' exist + if response.llm_output and "token_usage" in response.llm_output: + self.prompt_tokens = response.llm_output["token_usage"].get( + "prompt_tokens", 0 + ) + self.total_tokens = response.llm_output["token_usage"].get( + "total_tokens", 0 + ) + self.completion_tokens = response.llm_output["token_usage"].get( + "completion_tokens", 0 + ) + else: + self.prompt_tokens = ( + self.total_tokens + ) = self.completion_tokens = 0 # assign default value + i = 0 - # Go through each prompt response pair and generate embeddings as - # well as timestamp and prediction ids for generations in response.generations: for generation in generations: prompt = self.prompt_records[i] @@ -97,72 +94,98 @@ class ArizeCallbackHandler(BaseCallbackHandler): text_col=pd.Series(prompt.replace("\n", " ")) ).reset_index(drop=True) ) - generated_text = generation.text.replace("\n", " ") + + # Assigning text to response_text instead of response + response_text = generation.text.replace("\n", " ") response_embedding = pd.Series( self.generator.generate_embeddings( text_col=pd.Series(generation.text.replace("\n", " ")) ).reset_index(drop=True) ) - pred_id = str(uuid.uuid4()) - - # Define embedding features for Arize ingestion - embedding_features = { - "prompt_embedding": Embedding( - vector=pd.Series(prompt_embedding[0]), data=prompt - ), - "response_embedding": Embedding( - vector=pd.Series(response_embedding[0]), data=generated_text - ), - } - tags = { - "Prompt Tokens": self.prompt_tokens, - "Completion Tokens": self.completion_tokens, - "Total Tokens": self.total_tokens, - } - - # Log each prompt response data into arize - future = self.arize_client.log( - prediction_id=pred_id, - tags=tags, - prediction_label="1", + str(uuid.uuid4()) + pred_timestamp = datetime.now().timestamp() + + # Define the columns and data + columns = [ + "prediction_ts", + "response", + "prompt", + "response_vector", + "prompt_vector", + "prompt_token", + "completion_token", + "total_token", + ] + data = [ + [ + pred_timestamp, + response_text, + prompt, + response_embedding[0], + prompt_embedding[0], + self.prompt_tokens, + self.total_tokens, + self.completion_tokens, + ] + ] + + # Create the DataFrame + df = pd.DataFrame(data, columns=columns) + + # Declare prompt and response columns + prompt_columns = EmbeddingColumnNames( + vector_column_name="prompt_vector", data_column_name="prompt" + ) + + response_columns = EmbeddingColumnNames( + vector_column_name="response_vector", data_column_name="response" + ) + + schema = Schema( + timestamp_column_name="prediction_ts", + tag_column_names=[ + "prompt_token", + "completion_token", + "total_token", + ], + prompt_column_names=prompt_columns, + response_column_names=response_columns, + ) + + response_from_arize = self.arize_client.log( + dataframe=df, + schema=schema, model_id=self.model_id, - model_type=ModelTypes.SCORE_CATEGORICAL, model_version=self.model_version, + model_type=ModelTypes.GENERATIVE_LLM, environment=Environments.PRODUCTION, - embedding_features=embedding_features, ) - - result = future.result() - if result.status_code == 200: + if response_from_arize.status_code == 200: print("✅ Successfully logged data to Arize!") else: - print( - f"❌ Logging failed with status code {result.status_code}" - f' and message "{result.text}"' - ) + print(f'❌ Logging failed "{response_from_arize.text}"') i = i + 1 def on_llm_error( self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any ) -> None: - """Do nothing when LLM outputs an error.""" + """Do nothing.""" pass def on_chain_start( self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any ) -> None: - """Do nothing when LLM chain starts.""" pass def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: - """Do nothing when LLM chain ends.""" + """Do nothing.""" pass def on_chain_error( self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any ) -> None: - """Do nothing when LLM chain outputs an error.""" + """Do nothing.""" pass def on_tool_start( @@ -171,11 +194,10 @@ class ArizeCallbackHandler(BaseCallbackHandler): input_str: str, **kwargs: Any, ) -> None: - """Do nothing when tool starts.""" pass def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any: - """Do nothing when agent takes a specific action.""" + """Do nothing.""" pass def on_tool_end( @@ -185,19 +207,15 @@ class ArizeCallbackHandler(BaseCallbackHandler): llm_prefix: Optional[str] = None, **kwargs: Any, ) -> None: - """Do nothing when tool ends.""" pass def on_tool_error( self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any ) -> None: - """Do nothing when tool outputs an error.""" pass def on_text(self, text: str, **kwargs: Any) -> None: - """Do nothing""" pass def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None: - """Do nothing""" pass