mirror of
https://github.com/hwchase17/langchain
synced 2024-11-02 09:40:22 +00:00
43db4cd20e
This PR updates the on_tool_end handlers to return the raw output from the tool instead of casting it to a string. This is technically a breaking change, though its impact is expected to be somewhat minimal. It will fix behavior in `astream_events` as well. Fixes the following issue #18760 raised by @eyurtsev --------- Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
214 lines
7.3 KiB
Python
214 lines
7.3 KiB
Python
from datetime import datetime
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
from langchain_core.agents import AgentAction, AgentFinish
|
|
from langchain_core.callbacks import BaseCallbackHandler
|
|
from langchain_core.outputs import LLMResult
|
|
|
|
from langchain_community.callbacks.utils import import_pandas
|
|
|
|
|
|
class ArizeCallbackHandler(BaseCallbackHandler):
    """Callback Handler that logs to Arize.

    Prompts are recorded when an LLM run starts. When the run ends, each
    generation is embedded together with its prompt and sent to Arize as a
    one-row dataframe tagged with the token usage reported by the LLM.
    """

    def __init__(
        self,
        model_id: Optional[str] = None,
        model_version: Optional[str] = None,
        SPACE_KEY: Optional[str] = None,
        API_KEY: Optional[str] = None,
    ) -> None:
        """Initialize callback handler.

        Args:
            model_id: Model identifier forwarded to the Arize client on
                every ``log`` call.
            model_version: Model version forwarded to the Arize client on
                every ``log`` call.
            SPACE_KEY: Arize space key used to build the client.
            API_KEY: Arize API key used to build the client.

        Raises:
            ValueError: If ``SPACE_KEY`` or ``API_KEY`` still holds the
                literal placeholder value ``"SPACE_KEY"`` / ``"API_KEY"``.
        """
        super().__init__()
        self.model_id = model_id
        self.model_version = model_version
        self.space_key = SPACE_KEY
        self.api_key = API_KEY
        self.prompt_records: List[str] = []
        self.response_records: List[str] = []
        self.prediction_ids: List[str] = []
        self.pred_timestamps: List[int] = []
        self.response_embeddings: List[float] = []
        self.prompt_embeddings: List[float] = []
        self.prompt_tokens = 0
        self.completion_tokens = 0
        self.total_tokens = 0
        # Index into ``prompt_records`` of the next prompt awaiting a response.
        self.step = 0

        from arize.pandas.embeddings import EmbeddingGenerator, UseCases
        from arize.pandas.logger import Client

        # Reject placeholder credentials before building the embedding
        # generator and the client, so a misconfigured handler fails fast.
        if SPACE_KEY == "SPACE_KEY" or API_KEY == "API_KEY":
            raise ValueError("❌ CHANGE SPACE AND API KEYS")

        self.generator = EmbeddingGenerator.from_use_case(
            use_case=UseCases.NLP.SEQUENCE_CLASSIFICATION,
            model_name="distilbert-base-uncased",
            tokenizer_max_length=512,
            batch_size=256,
        )
        self.arize_client = Client(space_key=SPACE_KEY, api_key=API_KEY)
        print("✅ Arize client setup done! Now you can start using Arize!")  # noqa: T201

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Record every prompt (with newlines stripped) for later logging."""
        for prompt in prompts:
            self.prompt_records.append(prompt.replace("\n", ""))

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Do nothing."""
        pass

    def _embed_text(self, pd: Any, text: str) -> Any:
        """Return the embedding produced by ``self.generator`` for ``text``.

        Args:
            pd: The pandas module, as returned by ``import_pandas``.
            text: The text to embed.
        """
        embeddings = pd.Series(
            self.generator.generate_embeddings(text_col=pd.Series(text)).reset_index(
                drop=True
            )
        )
        return embeddings[0]

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Embed each generation with its prompt and log the pair to Arize.

        Token counts are read from ``response.llm_output["token_usage"]``
        when present and default to 0 otherwise. One dataframe row is sent
        per generation; the outcome of each send is printed.

        Args:
            response: The LLM result whose generations are logged.
            **kwargs: Ignored.
        """
        pd = import_pandas()
        from arize.utils.types import (
            EmbeddingColumnNames,
            Environments,
            ModelTypes,
            Schema,
        )

        # Missing ``llm_output``, a missing ``token_usage`` key, or a
        # ``None`` value are all treated as "no usage reported".
        token_usage = (response.llm_output or {}).get("token_usage") or {}
        self.prompt_tokens = token_usage.get("prompt_tokens", 0)
        self.total_tokens = token_usage.get("total_tokens", 0)
        self.completion_tokens = token_usage.get("completion_tokens", 0)

        columns = [
            "prediction_ts",
            "response",
            "prompt",
            "response_vector",
            "prompt_vector",
            "prompt_token",
            "completion_token",
            "total_token",
        ]

        # The schema does not depend on the generation, so build it once.
        schema = Schema(
            timestamp_column_name="prediction_ts",
            tag_column_names=[
                "prompt_token",
                "completion_token",
                "total_token",
            ],
            prompt_column_names=EmbeddingColumnNames(
                vector_column_name="prompt_vector", data_column_name="prompt"
            ),
            response_column_names=EmbeddingColumnNames(
                vector_column_name="response_vector", data_column_name="response"
            ),
        )

        for generations in response.generations:
            # ``response.generations`` holds one inner list per prompt, so
            # the prompt cursor advances once per inner list. Advancing it
            # per generation would pair extra generations with the wrong
            # prompt.
            prompt = self.prompt_records[self.step]
            self.step += 1
            if not generations:
                continue

            prompt_embedding = self._embed_text(pd, prompt.replace("\n", " "))

            for generation in generations:
                response_text = generation.text.replace("\n", " ")
                response_embedding = self._embed_text(pd, response_text)
                pred_timestamp = datetime.now().timestamp()

                # Values must follow the order of ``columns`` exactly.
                data = [
                    [
                        pred_timestamp,
                        response_text,
                        prompt,
                        response_embedding,
                        prompt_embedding,
                        self.prompt_tokens,
                        self.completion_tokens,
                        self.total_tokens,
                    ]
                ]
                df = pd.DataFrame(data, columns=columns)

                response_from_arize = self.arize_client.log(
                    dataframe=df,
                    schema=schema,
                    model_id=self.model_id,
                    model_version=self.model_version,
                    model_type=ModelTypes.GENERATIVE_LLM,
                    environment=Environments.PRODUCTION,
                )
                if response_from_arize.status_code == 200:
                    print("✅ Successfully logged data to Arize!")  # noqa: T201
                else:
                    print(f'❌ Logging failed "{response_from_arize.text}"')  # noqa: T201

    def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
        """Do nothing."""
        pass

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Do nothing."""
        pass

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """Do nothing."""
        pass

    def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
        """Do nothing."""
        pass

    def on_tool_start(
        self,
        serialized: Dict[str, Any],
        input_str: str,
        **kwargs: Any,
    ) -> None:
        """Do nothing."""
        pass

    def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
        """Do nothing."""
        pass

    def on_tool_end(
        self,
        output: Any,
        observation_prefix: Optional[str] = None,
        llm_prefix: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Do nothing."""
        pass

    def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
        """Do nothing."""
        pass

    def on_text(self, text: str, **kwargs: Any) -> None:
        """Do nothing."""
        pass

    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
        """Do nothing."""
        pass
|