Update arize_callback.py (#6433)

Arize released a new Generative LLM Model Type, adjusting the callback
function to new logging.

Added arize imports, please delete if not necessary.

Specifically, this change makes sure that the prompt and response pairs
from LangChain agents are logged into Arize as a Generative LLM model,
instead of our previous categorical model. In order to do this, the
callback functions collects the necessary data and passes the data into
Arize using Python Pandas SDK.

Arize library, specifically pandas.logger is an additional dependency.

Notebook For Test:
https://docs.arize.com/arize/resources/integrations/langchain

Who can review?
Tag maintainers/contributors who might be interested:

@hwchase17 - project lead

Tracing / Callbacks

@agola11
This commit is contained in:
Hakan Tekgul 2023-06-19 18:33:49 -07:00 committed by GitHub
parent 00f276d23f
commit 6a157cf8bb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,13 +1,14 @@
# Import the necessary packages for ingestion
import uuid import uuid
from datetime import datetime
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
from langchain.callbacks.base import BaseCallbackHandler from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.utils import import_pandas
from langchain.schema import AgentAction, AgentFinish, LLMResult from langchain.schema import AgentAction, AgentFinish, LLMResult
class ArizeCallbackHandler(BaseCallbackHandler): class ArizeCallbackHandler(BaseCallbackHandler):
"""Callback Handler that logs to Arize platform.""" """Callback Handler that logs to Arize."""
def __init__( def __init__(
self, self,
@ -19,41 +20,29 @@ class ArizeCallbackHandler(BaseCallbackHandler):
"""Initialize callback handler.""" """Initialize callback handler."""
super().__init__() super().__init__()
# Set the model_id and model_version for the Arize monitoring.
self.model_id = model_id self.model_id = model_id
self.model_version = model_version self.model_version = model_version
# Set the SPACE_KEY and API_KEY for the Arize client.
self.space_key = SPACE_KEY self.space_key = SPACE_KEY
self.api_key = API_KEY self.api_key = API_KEY
self.prompt_records: List[str] = []
# Initialize empty lists to store the prompt/response pairs self.response_records: List[str] = []
# and other necessary data. self.prediction_ids: List[str] = []
self.prompt_records: List = [] self.pred_timestamps: List[int] = []
self.response_records: List = [] self.response_embeddings: List[float] = []
self.prediction_ids: List = [] self.prompt_embeddings: List[float] = []
self.pred_timestamps: List = []
self.response_embeddings: List = []
self.prompt_embeddings: List = []
self.prompt_tokens = 0 self.prompt_tokens = 0
self.completion_tokens = 0 self.completion_tokens = 0
self.total_tokens = 0 self.total_tokens = 0
from arize.api import Client
from arize.pandas.embeddings import EmbeddingGenerator, UseCases from arize.pandas.embeddings import EmbeddingGenerator, UseCases
from arize.pandas.logger import Client
# Create an embedding generator for generating embeddings
# from prompts and responses.
self.generator = EmbeddingGenerator.from_use_case( self.generator = EmbeddingGenerator.from_use_case(
use_case=UseCases.NLP.SEQUENCE_CLASSIFICATION, use_case=UseCases.NLP.SEQUENCE_CLASSIFICATION,
model_name="distilbert-base-uncased", model_name="distilbert-base-uncased",
tokenizer_max_length=512, tokenizer_max_length=512,
batch_size=256, batch_size=256,
) )
# Create an Arize client and check if the SPACE_KEY and API_KEY
# are not set to the default values.
self.arize_client = Client(space_key=SPACE_KEY, api_key=API_KEY) self.arize_client = Client(space_key=SPACE_KEY, api_key=API_KEY)
if SPACE_KEY == "SPACE_KEY" or API_KEY == "API_KEY": if SPACE_KEY == "SPACE_KEY" or API_KEY == "API_KEY":
raise ValueError("❌ CHANGE SPACE AND API KEYS") raise ValueError("❌ CHANGE SPACE AND API KEYS")
@ -63,32 +52,40 @@ class ArizeCallbackHandler(BaseCallbackHandler):
def on_llm_start( def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None: ) -> None:
"""Record the prompts when an LLM starts."""
for prompt in prompts: for prompt in prompts:
self.prompt_records.append(prompt.replace("\n", " ")) self.prompt_records.append(prompt.replace("\n", ""))
def on_llm_new_token(self, token: str, **kwargs: Any) -> None: def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Do nothing when a new token is generated.""" """Do nothing."""
pass pass
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Log data to Arize when an LLM ends.""" pd = import_pandas()
from arize.utils.types import (
EmbeddingColumnNames,
Environments,
ModelTypes,
Schema,
)
import pandas as pd # Safe check if 'llm_output' and 'token_usage' exist
from arize.utils.types import Embedding, Environments, ModelTypes if response.llm_output and "token_usage" in response.llm_output:
self.prompt_tokens = response.llm_output["token_usage"].get(
"prompt_tokens", 0
)
self.total_tokens = response.llm_output["token_usage"].get(
"total_tokens", 0
)
self.completion_tokens = response.llm_output["token_usage"].get(
"completion_tokens", 0
)
else:
self.prompt_tokens = (
self.total_tokens
) = self.completion_tokens = 0 # assign default value
# Record token usage of the LLM
if response.llm_output is not None:
self.prompt_tokens = response.llm_output["token_usage"]["prompt_tokens"]
self.total_tokens = response.llm_output["token_usage"]["total_tokens"]
self.completion_tokens = response.llm_output["token_usage"][
"completion_tokens"
]
i = 0 i = 0
# Go through each prompt response pair and generate embeddings as
# well as timestamp and prediction ids
for generations in response.generations: for generations in response.generations:
for generation in generations: for generation in generations:
prompt = self.prompt_records[i] prompt = self.prompt_records[i]
@ -97,72 +94,98 @@ class ArizeCallbackHandler(BaseCallbackHandler):
text_col=pd.Series(prompt.replace("\n", " ")) text_col=pd.Series(prompt.replace("\n", " "))
).reset_index(drop=True) ).reset_index(drop=True)
) )
generated_text = generation.text.replace("\n", " ")
# Assigning text to response_text instead of response
response_text = generation.text.replace("\n", " ")
response_embedding = pd.Series( response_embedding = pd.Series(
self.generator.generate_embeddings( self.generator.generate_embeddings(
text_col=pd.Series(generation.text.replace("\n", " ")) text_col=pd.Series(generation.text.replace("\n", " "))
).reset_index(drop=True) ).reset_index(drop=True)
) )
pred_id = str(uuid.uuid4()) str(uuid.uuid4())
pred_timestamp = datetime.now().timestamp()
# Define embedding features for Arize ingestion # Define the columns and data
embedding_features = { columns = [
"prompt_embedding": Embedding( "prediction_ts",
vector=pd.Series(prompt_embedding[0]), data=prompt "response",
), "prompt",
"response_embedding": Embedding( "response_vector",
vector=pd.Series(response_embedding[0]), data=generated_text "prompt_vector",
), "prompt_token",
} "completion_token",
tags = { "total_token",
"Prompt Tokens": self.prompt_tokens, ]
"Completion Tokens": self.completion_tokens, data = [
"Total Tokens": self.total_tokens, [
} pred_timestamp,
response_text,
prompt,
response_embedding[0],
prompt_embedding[0],
self.prompt_tokens,
self.total_tokens,
self.completion_tokens,
]
]
# Log each prompt response data into arize # Create the DataFrame
future = self.arize_client.log( df = pd.DataFrame(data, columns=columns)
prediction_id=pred_id,
tags=tags, # Declare prompt and response columns
prediction_label="1", prompt_columns = EmbeddingColumnNames(
model_id=self.model_id, vector_column_name="prompt_vector", data_column_name="prompt"
model_type=ModelTypes.SCORE_CATEGORICAL,
model_version=self.model_version,
environment=Environments.PRODUCTION,
embedding_features=embedding_features,
) )
result = future.result() response_columns = EmbeddingColumnNames(
if result.status_code == 200: vector_column_name="response_vector", data_column_name="response"
)
schema = Schema(
timestamp_column_name="prediction_ts",
tag_column_names=[
"prompt_token",
"completion_token",
"total_token",
],
prompt_column_names=prompt_columns,
response_column_names=response_columns,
)
response_from_arize = self.arize_client.log(
dataframe=df,
schema=schema,
model_id=self.model_id,
model_version=self.model_version,
model_type=ModelTypes.GENERATIVE_LLM,
environment=Environments.PRODUCTION,
)
if response_from_arize.status_code == 200:
print("✅ Successfully logged data to Arize!") print("✅ Successfully logged data to Arize!")
else: else:
print( print(f'❌ Logging failed "{response_from_arize.text}"')
f"❌ Logging failed with status code {result.status_code}"
f' and message "{result.text}"'
)
i = i + 1 i = i + 1
def on_llm_error( def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None: ) -> None:
"""Do nothing when LLM outputs an error.""" """Do nothing."""
pass pass
def on_chain_start( def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None: ) -> None:
"""Do nothing when LLM chain starts."""
pass pass
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Do nothing when LLM chain ends.""" """Do nothing."""
pass pass
def on_chain_error( def on_chain_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None: ) -> None:
"""Do nothing when LLM chain outputs an error.""" """Do nothing."""
pass pass
def on_tool_start( def on_tool_start(
@ -171,11 +194,10 @@ class ArizeCallbackHandler(BaseCallbackHandler):
input_str: str, input_str: str,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
"""Do nothing when tool starts."""
pass pass
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any: def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Do nothing when agent takes a specific action.""" """Do nothing."""
pass pass
def on_tool_end( def on_tool_end(
@ -185,19 +207,15 @@ class ArizeCallbackHandler(BaseCallbackHandler):
llm_prefix: Optional[str] = None, llm_prefix: Optional[str] = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
"""Do nothing when tool ends."""
pass pass
def on_tool_error( def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None: ) -> None:
"""Do nothing when tool outputs an error."""
pass pass
def on_text(self, text: str, **kwargs: Any) -> None: def on_text(self, text: str, **kwargs: Any) -> None:
"""Do nothing"""
pass pass
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None: def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Do nothing"""
pass pass