Mirror of https://github.com/hwchase17/langchain, synced 2024-11-06 03:20:49 +00:00
Update arize_callback.py (#6433)
Arize released a new Generative LLM model type, and this change adjusts the callback to the new logging API. Arize imports were added; please remove any that turn out to be unnecessary. Specifically, this change ensures that prompt/response pairs from LangChain agents are logged to Arize as a Generative LLM model instead of our previous categorical model. To do this, the callback collects the necessary data and sends it to Arize via the Arize pandas SDK; the arize library (specifically arize.pandas.logger) is an additional dependency.

Notebook for testing: https://docs.arize.com/arize/resources/integrations/langchain

Who can review? Tag maintainers/contributors who might be interested:
@hwchase17 - project lead
Tracing / Callbacks: @agola11
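For reviewers, a minimal usage sketch showing how the handler is attached to an LLM. The constructor arguments follow this diff's `__init__` signature; the `OpenAI` model, the `callbacks=` wiring, the `model_id`, and the key values are illustrative assumptions, not part of this change:

```python
from langchain.callbacks.arize_callback import ArizeCallbackHandler
from langchain.llms import OpenAI  # illustrative LLM; any LangChain LLM works

# Placeholder credentials: the handler raises ValueError if the defaults
# "SPACE_KEY" / "API_KEY" are left unchanged.
arize_callback = ArizeCallbackHandler(
    model_id="langchain-llm-demo",  # illustrative model id
    model_version="1.0",            # illustrative version
    SPACE_KEY="YOUR_SPACE_KEY",
    API_KEY="YOUR_API_KEY",
)

# Each prompt/response pair produced by this LLM is logged to Arize from
# on_llm_end as a GENERATIVE_LLM model.
llm = OpenAI(temperature=0, callbacks=[arize_callback])
llm("Why is the sky blue?")
```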
parent 00f276d23f
commit 6a157cf8bb
```diff
@@ -1,13 +1,14 @@
-# Import the necessary packages for ingestion
 import uuid
+from datetime import datetime
 from typing import Any, Dict, List, Optional, Union
 
 from langchain.callbacks.base import BaseCallbackHandler
+from langchain.callbacks.utils import import_pandas
 from langchain.schema import AgentAction, AgentFinish, LLMResult
 
 
 class ArizeCallbackHandler(BaseCallbackHandler):
-    """Callback Handler that logs to Arize platform."""
+    """Callback Handler that logs to Arize."""
 
     def __init__(
         self,
@@ -19,41 +20,29 @@ class ArizeCallbackHandler(BaseCallbackHandler):
         """Initialize callback handler."""
 
         super().__init__()
 
-        # Set the model_id and model_version for the Arize monitoring.
         self.model_id = model_id
         self.model_version = model_version
 
-        # Set the SPACE_KEY and API_KEY for the Arize client.
         self.space_key = SPACE_KEY
         self.api_key = API_KEY
-        # Initialize empty lists to store the prompt/response pairs
-        # and other necessary data.
-        self.prompt_records: List = []
-        self.response_records: List = []
-        self.prediction_ids: List = []
-        self.pred_timestamps: List = []
-        self.response_embeddings: List = []
-        self.prompt_embeddings: List = []
+        self.prompt_records: List[str] = []
+        self.response_records: List[str] = []
+        self.prediction_ids: List[str] = []
+        self.pred_timestamps: List[int] = []
+        self.response_embeddings: List[float] = []
+        self.prompt_embeddings: List[float] = []
         self.prompt_tokens = 0
         self.completion_tokens = 0
         self.total_tokens = 0
 
-        from arize.api import Client
         from arize.pandas.embeddings import EmbeddingGenerator, UseCases
+        from arize.pandas.logger import Client
 
-        # Create an embedding generator for generating embeddings
-        # from prompts and responses.
         self.generator = EmbeddingGenerator.from_use_case(
             use_case=UseCases.NLP.SEQUENCE_CLASSIFICATION,
             model_name="distilbert-base-uncased",
             tokenizer_max_length=512,
             batch_size=256,
         )
 
-        # Create an Arize client and check if the SPACE_KEY and API_KEY
-        # are not set to the default values.
         self.arize_client = Client(space_key=SPACE_KEY, api_key=API_KEY)
         if SPACE_KEY == "SPACE_KEY" or API_KEY == "API_KEY":
             raise ValueError("❌ CHANGE SPACE AND API KEYS")
@@ -63,32 +52,40 @@ class ArizeCallbackHandler(BaseCallbackHandler):
     def on_llm_start(
         self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
     ) -> None:
-        """Record the prompts when an LLM starts."""
-
         for prompt in prompts:
-            self.prompt_records.append(prompt.replace("\n", " "))
+            self.prompt_records.append(prompt.replace("\n", ""))
 
     def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
-        """Do nothing when a new token is generated."""
+        """Do nothing."""
         pass
 
     def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
-        """Log data to Arize when an LLM ends."""
-
-        import pandas as pd
-        from arize.utils.types import Embedding, Environments, ModelTypes
+        pd = import_pandas()
+        from arize.utils.types import (
+            EmbeddingColumnNames,
+            Environments,
+            ModelTypes,
+            Schema,
+        )
 
-        # Record token usage of the LLM
-        if response.llm_output is not None:
-            self.prompt_tokens = response.llm_output["token_usage"]["prompt_tokens"]
-            self.total_tokens = response.llm_output["token_usage"]["total_tokens"]
-            self.completion_tokens = response.llm_output["token_usage"][
-                "completion_tokens"
-            ]
+        # Safe check if 'llm_output' and 'token_usage' exist
+        if response.llm_output and "token_usage" in response.llm_output:
+            self.prompt_tokens = response.llm_output["token_usage"].get(
+                "prompt_tokens", 0
+            )
+            self.total_tokens = response.llm_output["token_usage"].get(
+                "total_tokens", 0
+            )
+            self.completion_tokens = response.llm_output["token_usage"].get(
+                "completion_tokens", 0
+            )
+        else:
+            self.prompt_tokens = (
+                self.total_tokens
+            ) = self.completion_tokens = 0  # assign default value
 
         i = 0
 
-        # Go through each prompt response pair and generate embeddings as
-        # well as timestamp and prediction ids
         for generations in response.generations:
             for generation in generations:
                 prompt = self.prompt_records[i]
@@ -97,72 +94,98 @@ class ArizeCallbackHandler(BaseCallbackHandler):
                         text_col=pd.Series(prompt.replace("\n", " "))
                     ).reset_index(drop=True)
                 )
-                generated_text = generation.text.replace("\n", " ")
+                # Assigning text to response_text instead of response
+                response_text = generation.text.replace("\n", " ")
                 response_embedding = pd.Series(
                     self.generator.generate_embeddings(
                         text_col=pd.Series(generation.text.replace("\n", " "))
                     ).reset_index(drop=True)
                 )
-                pred_id = str(uuid.uuid4())
+                str(uuid.uuid4())
+                pred_timestamp = datetime.now().timestamp()
 
-                # Define embedding features for Arize ingestion
-                embedding_features = {
-                    "prompt_embedding": Embedding(
-                        vector=pd.Series(prompt_embedding[0]), data=prompt
-                    ),
-                    "response_embedding": Embedding(
-                        vector=pd.Series(response_embedding[0]), data=generated_text
-                    ),
-                }
-                tags = {
-                    "Prompt Tokens": self.prompt_tokens,
-                    "Completion Tokens": self.completion_tokens,
-                    "Total Tokens": self.total_tokens,
-                }
+                # Define the columns and data
+                columns = [
+                    "prediction_ts",
+                    "response",
+                    "prompt",
+                    "response_vector",
+                    "prompt_vector",
+                    "prompt_token",
+                    "completion_token",
+                    "total_token",
+                ]
+                data = [
+                    [
+                        pred_timestamp,
+                        response_text,
+                        prompt,
+                        response_embedding[0],
+                        prompt_embedding[0],
+                        self.prompt_tokens,
+                        self.total_tokens,
+                        self.completion_tokens,
+                    ]
+                ]
 
-                # Log each prompt response data into arize
-                future = self.arize_client.log(
-                    prediction_id=pred_id,
-                    tags=tags,
-                    prediction_label="1",
-                    model_id=self.model_id,
-                    model_type=ModelTypes.SCORE_CATEGORICAL,
-                    model_version=self.model_version,
-                    environment=Environments.PRODUCTION,
-                    embedding_features=embedding_features,
-                )
+                # Create the DataFrame
+                df = pd.DataFrame(data, columns=columns)
 
-                result = future.result()
-                if result.status_code == 200:
+                # Declare prompt and response columns
+                prompt_columns = EmbeddingColumnNames(
+                    vector_column_name="prompt_vector", data_column_name="prompt"
+                )
+
+                response_columns = EmbeddingColumnNames(
+                    vector_column_name="response_vector", data_column_name="response"
+                )
+
+                schema = Schema(
+                    timestamp_column_name="prediction_ts",
+                    tag_column_names=[
+                        "prompt_token",
+                        "completion_token",
+                        "total_token",
+                    ],
+                    prompt_column_names=prompt_columns,
+                    response_column_names=response_columns,
+                )
+
+                response_from_arize = self.arize_client.log(
+                    dataframe=df,
+                    schema=schema,
+                    model_id=self.model_id,
+                    model_version=self.model_version,
+                    model_type=ModelTypes.GENERATIVE_LLM,
+                    environment=Environments.PRODUCTION,
+                )
+                if response_from_arize.status_code == 200:
                     print("✅ Successfully logged data to Arize!")
                 else:
-                    print(
-                        f"❌ Logging failed with status code {result.status_code}"
-                        f' and message "{result.text}"'
-                    )
+                    print(f'❌ Logging failed "{response_from_arize.text}"')
 
                 i = i + 1
 
     def on_llm_error(
         self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
     ) -> None:
-        """Do nothing when LLM outputs an error."""
+        """Do nothing."""
         pass
 
     def on_chain_start(
         self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
     ) -> None:
-        """Do nothing when LLM chain starts."""
         pass
 
     def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
-        """Do nothing when LLM chain ends."""
+        """Do nothing."""
         pass
 
     def on_chain_error(
         self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
-        """Do nothing when LLM chain outputs an error."""
+        """Do nothing."""
         pass
 
     def on_tool_start(
@@ -171,11 +194,10 @@ class ArizeCallbackHandler(BaseCallbackHandler):
         input_str: str,
         **kwargs: Any,
     ) -> None:
-        """Do nothing when tool starts."""
         pass
 
     def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
-        """Do nothing when agent takes a specific action."""
+        """Do nothing."""
         pass
 
     def on_tool_end(
@@ -185,19 +207,15 @@ class ArizeCallbackHandler(BaseCallbackHandler):
         llm_prefix: Optional[str] = None,
         **kwargs: Any,
     ) -> None:
-        """Do nothing when tool ends."""
         pass
 
     def on_tool_error(
         self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
     ) -> None:
-        """Do nothing when tool outputs an error."""
         pass
 
     def on_text(self, text: str, **kwargs: Any) -> None:
-        """Do nothing"""
         pass
 
     def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
-        """Do nothing"""
         pass
```
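Taken together, the core of the change is the switch from the record-at-a-time `arize.api.Client` with `embedding_features`/`tags` to the batch-oriented `arize.pandas.logger.Client`, which ingests a pandas DataFrame described by a `Schema`. Below is a standalone sketch of that new logging flow, using the column names and API calls from the diff but with dummy data; the vector lengths, token counts, `model_id`, and credential values are all illustrative assumptions:

```python
import pandas as pd
from arize.pandas.logger import Client
from arize.utils.types import EmbeddingColumnNames, Environments, ModelTypes, Schema

client = Client(space_key="YOUR_SPACE_KEY", api_key="YOUR_API_KEY")

# One row per prompt/response pair, mirroring the columns built in on_llm_end.
df = pd.DataFrame(
    {
        "prediction_ts": [1687000000.0],
        "prompt": ["Why is the sky blue?"],
        "response": ["Because of Rayleigh scattering."],
        "prompt_vector": [[0.1] * 768],    # dummy embedding
        "response_vector": [[0.2] * 768],  # dummy embedding
        "prompt_token": [5],
        "completion_token": [6],
        "total_token": [11],
    }
)

# The Schema maps DataFrame columns to Arize fields: the timestamp, the
# token-count tags, and the prompt/response text + embedding column pairs.
schema = Schema(
    timestamp_column_name="prediction_ts",
    tag_column_names=["prompt_token", "completion_token", "total_token"],
    prompt_column_names=EmbeddingColumnNames(
        vector_column_name="prompt_vector", data_column_name="prompt"
    ),
    response_column_names=EmbeddingColumnNames(
        vector_column_name="response_vector", data_column_name="response"
    ),
)

response = client.log(
    dataframe=df,
    schema=schema,
    model_id="langchain-llm-demo",  # illustrative model id
    model_version="1.0",
    model_type=ModelTypes.GENERATIVE_LLM,
    environment=Environments.PRODUCTION,
)
assert response.status_code == 200, response.text
```

Batching through a DataFrame is what lets the handler attach embeddings and token-count tags as schema-described columns rather than per-record `embedding_features`.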