Update arize_callback.py (#6433)

Arize released a new Generative LLM Model Type, adjusting the callback
function to new logging.

Added arize imports, please delete if not necessary.

Specifically, this change makes sure that the prompt and response pairs
from LangChain agents are logged into Arize as a Generative LLM model,
instead of our previous categorical model. In order to do this, the
callback functions collects the necessary data and passes the data into
Arize using Python Pandas SDK.

Arize library, specifically pandas.logger is an additional dependency.

Notebook For Test:
https://docs.arize.com/arize/resources/integrations/langchain

Who can review?
Tag maintainers/contributors who might be interested:

@hwchase17 - project lead

Tracing / Callbacks

@agola11
master
Hakan Tekgul 11 months ago committed by GitHub
parent 00f276d23f
commit 6a157cf8bb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,13 +1,14 @@
# Import the necessary packages for ingestion
import uuid
from datetime import datetime
from typing import Any, Dict, List, Optional, Union
from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.utils import import_pandas
from langchain.schema import AgentAction, AgentFinish, LLMResult
class ArizeCallbackHandler(BaseCallbackHandler):
"""Callback Handler that logs to Arize platform."""
"""Callback Handler that logs to Arize."""
def __init__(
self,
@ -19,41 +20,29 @@ class ArizeCallbackHandler(BaseCallbackHandler):
"""Initialize callback handler."""
super().__init__()
# Set the model_id and model_version for the Arize monitoring.
self.model_id = model_id
self.model_version = model_version
# Set the SPACE_KEY and API_KEY for the Arize client.
self.space_key = SPACE_KEY
self.api_key = API_KEY
# Initialize empty lists to store the prompt/response pairs
# and other necessary data.
self.prompt_records: List = []
self.response_records: List = []
self.prediction_ids: List = []
self.pred_timestamps: List = []
self.response_embeddings: List = []
self.prompt_embeddings: List = []
self.prompt_records: List[str] = []
self.response_records: List[str] = []
self.prediction_ids: List[str] = []
self.pred_timestamps: List[int] = []
self.response_embeddings: List[float] = []
self.prompt_embeddings: List[float] = []
self.prompt_tokens = 0
self.completion_tokens = 0
self.total_tokens = 0
from arize.api import Client
from arize.pandas.embeddings import EmbeddingGenerator, UseCases
from arize.pandas.logger import Client
# Create an embedding generator for generating embeddings
# from prompts and responses.
self.generator = EmbeddingGenerator.from_use_case(
use_case=UseCases.NLP.SEQUENCE_CLASSIFICATION,
model_name="distilbert-base-uncased",
tokenizer_max_length=512,
batch_size=256,
)
# Create an Arize client and check if the SPACE_KEY and API_KEY
# are not set to the default values.
self.arize_client = Client(space_key=SPACE_KEY, api_key=API_KEY)
if SPACE_KEY == "SPACE_KEY" or API_KEY == "API_KEY":
raise ValueError("❌ CHANGE SPACE AND API KEYS")
@ -63,32 +52,40 @@ class ArizeCallbackHandler(BaseCallbackHandler):
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Record the prompts when an LLM starts."""
for prompt in prompts:
self.prompt_records.append(prompt.replace("\n", " "))
self.prompt_records.append(prompt.replace("\n", ""))
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Do nothing when a new token is generated."""
"""Do nothing."""
pass
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Log data to Arize when an LLM ends."""
import pandas as pd
from arize.utils.types import Embedding, Environments, ModelTypes
# Record token usage of the LLM
if response.llm_output is not None:
self.prompt_tokens = response.llm_output["token_usage"]["prompt_tokens"]
self.total_tokens = response.llm_output["token_usage"]["total_tokens"]
self.completion_tokens = response.llm_output["token_usage"][
"completion_tokens"
]
pd = import_pandas()
from arize.utils.types import (
EmbeddingColumnNames,
Environments,
ModelTypes,
Schema,
)
# Safe check if 'llm_output' and 'token_usage' exist
if response.llm_output and "token_usage" in response.llm_output:
self.prompt_tokens = response.llm_output["token_usage"].get(
"prompt_tokens", 0
)
self.total_tokens = response.llm_output["token_usage"].get(
"total_tokens", 0
)
self.completion_tokens = response.llm_output["token_usage"].get(
"completion_tokens", 0
)
else:
self.prompt_tokens = (
self.total_tokens
) = self.completion_tokens = 0 # assign default value
i = 0
# Go through each prompt response pair and generate embeddings as
# well as timestamp and prediction ids
for generations in response.generations:
for generation in generations:
prompt = self.prompt_records[i]
@ -97,72 +94,98 @@ class ArizeCallbackHandler(BaseCallbackHandler):
text_col=pd.Series(prompt.replace("\n", " "))
).reset_index(drop=True)
)
generated_text = generation.text.replace("\n", " ")
# Assigning text to response_text instead of response
response_text = generation.text.replace("\n", " ")
response_embedding = pd.Series(
self.generator.generate_embeddings(
text_col=pd.Series(generation.text.replace("\n", " "))
).reset_index(drop=True)
)
pred_id = str(uuid.uuid4())
# Define embedding features for Arize ingestion
embedding_features = {
"prompt_embedding": Embedding(
vector=pd.Series(prompt_embedding[0]), data=prompt
),
"response_embedding": Embedding(
vector=pd.Series(response_embedding[0]), data=generated_text
),
}
tags = {
"Prompt Tokens": self.prompt_tokens,
"Completion Tokens": self.completion_tokens,
"Total Tokens": self.total_tokens,
}
# Log each prompt response data into arize
future = self.arize_client.log(
prediction_id=pred_id,
tags=tags,
prediction_label="1",
str(uuid.uuid4())
pred_timestamp = datetime.now().timestamp()
# Define the columns and data
columns = [
"prediction_ts",
"response",
"prompt",
"response_vector",
"prompt_vector",
"prompt_token",
"completion_token",
"total_token",
]
data = [
[
pred_timestamp,
response_text,
prompt,
response_embedding[0],
prompt_embedding[0],
self.prompt_tokens,
self.total_tokens,
self.completion_tokens,
]
]
# Create the DataFrame
df = pd.DataFrame(data, columns=columns)
# Declare prompt and response columns
prompt_columns = EmbeddingColumnNames(
vector_column_name="prompt_vector", data_column_name="prompt"
)
response_columns = EmbeddingColumnNames(
vector_column_name="response_vector", data_column_name="response"
)
schema = Schema(
timestamp_column_name="prediction_ts",
tag_column_names=[
"prompt_token",
"completion_token",
"total_token",
],
prompt_column_names=prompt_columns,
response_column_names=response_columns,
)
response_from_arize = self.arize_client.log(
dataframe=df,
schema=schema,
model_id=self.model_id,
model_type=ModelTypes.SCORE_CATEGORICAL,
model_version=self.model_version,
model_type=ModelTypes.GENERATIVE_LLM,
environment=Environments.PRODUCTION,
embedding_features=embedding_features,
)
result = future.result()
if result.status_code == 200:
if response_from_arize.status_code == 200:
print("✅ Successfully logged data to Arize!")
else:
print(
f"❌ Logging failed with status code {result.status_code}"
f' and message "{result.text}"'
)
print(f'❌ Logging failed "{response_from_arize.text}"')
i = i + 1
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Do nothing when LLM outputs an error."""
"""Do nothing."""
pass
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Do nothing when LLM chain starts."""
pass
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Do nothing when LLM chain ends."""
"""Do nothing."""
pass
def on_chain_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Do nothing when LLM chain outputs an error."""
"""Do nothing."""
pass
def on_tool_start(
@ -171,11 +194,10 @@ class ArizeCallbackHandler(BaseCallbackHandler):
input_str: str,
**kwargs: Any,
) -> None:
"""Do nothing when tool starts."""
pass
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Do nothing when agent takes a specific action."""
"""Do nothing."""
pass
def on_tool_end(
@ -185,19 +207,15 @@ class ArizeCallbackHandler(BaseCallbackHandler):
llm_prefix: Optional[str] = None,
**kwargs: Any,
) -> None:
"""Do nothing when tool ends."""
pass
def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Do nothing when tool outputs an error."""
pass
def on_text(self, text: str, **kwargs: Any) -> None:
"""Do nothing"""
pass
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Do nothing"""
pass

Loading…
Cancel
Save