mirror of
https://github.com/hwchase17/langchain
synced 2024-11-02 09:40:22 +00:00
43db4cd20e
This PR updates the on_tool_end handlers to return the raw output from the tool instead of casting it to a string. This is technically a breaking change, though its impact is expected to be somewhat minimal. It will fix behavior in `astream_events` as well. Fixes the following issue #18760 raised by @eyurtsev --------- Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
297 lines
11 KiB
Python
297 lines
11 KiB
Python
"""ArthurAI's Callback Handler."""
|
|
from __future__ import annotations
|
|
|
|
import os
|
|
import uuid
|
|
from collections import defaultdict
|
|
from datetime import datetime
|
|
from time import time
|
|
from typing import TYPE_CHECKING, Any, DefaultDict, Dict, List, Optional
|
|
|
|
import numpy as np
|
|
from langchain_core.agents import AgentAction, AgentFinish
|
|
from langchain_core.callbacks import BaseCallbackHandler
|
|
from langchain_core.outputs import LLMResult
|
|
|
|
if TYPE_CHECKING:
|
|
import arthurai
|
|
from arthurai.core.models import ArthurModel
|
|
|
|
# Keys looked up in `LLMResult.llm_output` / `generation_info`. The same
# strings double as the Arthur attribute names checked before a value is
# attached to an inference.
PROMPT_TOKENS = "prompt_tokens"
COMPLETION_TOKENS = "completion_tokens"
TOKEN_USAGE = "token_usage"
FINISH_REASON = "finish_reason"

# Arthur attribute name under which the start-to-end wall-clock time of an
# LLM run is reported.
DURATION = "duration"
|
|
|
|
|
|
def _lazy_load_arthur() -> arthurai:
|
|
"""Lazy load Arthur."""
|
|
try:
|
|
import arthurai
|
|
except ImportError as e:
|
|
raise ImportError(
|
|
"To use the ArthurCallbackHandler you need the"
|
|
" `arthurai` package. Please install it with"
|
|
" `pip install arthurai`.",
|
|
e,
|
|
)
|
|
|
|
return arthurai
|
|
|
|
|
|
class ArthurCallbackHandler(BaseCallbackHandler):
    """Callback Handler that logs to Arthur platform.

    Arthur helps enterprise teams optimize model operations
    and performance at scale. The Arthur API tracks model
    performance, explainability, and fairness across tabular,
    NLP, and CV models. Our API is model- and platform-agnostic,
    and continuously scales with complex and dynamic enterprise needs.
    To learn more about Arthur, visit our website at
    https://www.arthur.ai/ or read the Arthur docs at
    https://docs.arthur.ai/

    The handler records the prompts and start time of every LLM run in
    `on_llm_start` and, in `on_llm_end`, turns each generation into one
    inference that is sent to the configured Arthur model. All other
    callbacks are no-ops.
    """

    def __init__(
        self,
        arthur_model: ArthurModel,
    ) -> None:
        """Initialize callback handler.

        Args:
            arthur_model: The Arthur model that inferences are sent to. Its
                attributes decide which optional fields (finish reason,
                token usage, duration, token likelihoods) are logged.

        Raises:
            ImportError: If the `arthurai` package is not installed.
            IndexError: If the model has no unstructured-text pipeline-input
                attribute or no unstructured-text predicted-value attribute.
        """
        super().__init__()
        arthurai = _lazy_load_arthur()
        Stage = arthurai.common.constants.Stage
        ValueType = arthurai.common.constants.ValueType
        self.arthur_model = arthur_model

        # Fetch the model attributes a single time; they are consulted
        # repeatedly below and again (via attr_names) in on_llm_end().
        attributes = list(self.arthur_model.get_attributes())
        self.attr_names = {a.name for a in attributes}

        # The first unstructured-text input / predicted-value attributes are
        # the fields that receive the prompt and the generated text.
        self.input_attr = [
            a
            for a in attributes
            if a.stage == Stage.ModelPipelineInput
            and a.value_type == ValueType.Unstructured_Text
        ][0].name
        self.output_attr = [
            a
            for a in attributes
            if a.stage == Stage.PredictedValue
            and a.value_type == ValueType.Unstructured_Text
        ][0].name

        # Token likelihoods are optional: None when the model has no such
        # attribute, otherwise the name of the first one.
        self.token_likelihood_attr = next(
            (
                a.name
                for a in attributes
                if a.value_type == ValueType.TokenLikelihoods
            ),
            None,
        )

        # Per-run scratch data (prompts and start time), keyed by run ID.
        self.run_map: DefaultDict[str, Any] = defaultdict(dict)

    @classmethod
    def from_credentials(
        cls,
        model_id: str,
        arthur_url: Optional[str] = "https://app.arthur.ai",
        arthur_login: Optional[str] = None,
        arthur_password: Optional[str] = None,
    ) -> ArthurCallbackHandler:
        """Initialize callback handler from Arthur credentials.

        Args:
            model_id (str): The ID of the arthur model to log to.
            arthur_url (str, optional): The URL of the Arthur instance to log to.
                Defaults to "https://app.arthur.ai".
            arthur_login (str, optional): The login to use to connect to Arthur.
                Defaults to None, in which case the ARTHUR_API_KEY
                environment variable is used instead.
            arthur_password (str, optional): The password to use to connect to
                Arthur. Defaults to None.

        Returns:
            ArthurCallbackHandler: The initialized callback handler.

        Raises:
            ImportError: If the `arthurai` package is not installed.
            ValueError: If neither a login nor ARTHUR_API_KEY is provided, or
                if the model cannot be retrieved from Arthur.
        """
        arthurai = _lazy_load_arthur()
        ArthurAI = arthurai.ArthurAI
        ResponseClientError = arthurai.common.exceptions.ResponseClientError

        # connect to Arthur
        if arthur_login is None:
            try:
                arthur_api_key = os.environ["ARTHUR_API_KEY"]
            except KeyError as e:
                raise ValueError(
                    "No Arthur authentication provided. Either give"
                    " a login to the ArthurCallbackHandler"
                    " or set an ARTHUR_API_KEY as an environment variable."
                ) from e
            arthur = ArthurAI(url=arthur_url, access_key=arthur_api_key)
        elif arthur_password is None:
            # No password supplied: the client is constructed with the login
            # only.
            arthur = ArthurAI(url=arthur_url, login=arthur_login)
        else:
            arthur = ArthurAI(
                url=arthur_url, login=arthur_login, password=arthur_password
            )

        # get model from Arthur by the provided model ID
        try:
            arthur_model = arthur.get_model(model_id)
        except ResponseClientError as e:
            raise ValueError(
                f"Was unable to retrieve model with id {model_id} from Arthur."
                " Make sure the ID corresponds to a model that is currently"
                " registered with your Arthur account."
            ) from e
        return cls(arthur_model)

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """On LLM start, save the input prompts and the start time.

        Args:
            serialized: Serialized LLM description (unused).
            prompts: The prompts sent to the LLM for this run.
            **kwargs: Must contain `run_id`, the key under which the run's
                data is stored until `on_llm_end` is called.
        """
        run_id = kwargs["run_id"]
        self.run_map[run_id]["input_texts"] = prompts
        self.run_map[run_id]["start_time"] = time()

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """On LLM end, send data to Arthur.

        Builds one inference per generation in `response` and sends the whole
        batch with a single `send_inferences` call. The stored data for the
        run is removed from `run_map` in the process.

        Args:
            response: The LLM result whose generations are logged.
            **kwargs: Must contain `run_id`, matching an earlier
                `on_llm_start` call.

        Raises:
            ImportError: If `pytz` is not installed.
            KeyError: If `run_id` was never registered in `on_llm_start`.
        """
        try:
            import pytz
        except ImportError as e:
            raise ImportError(
                "Could not import pytz. Please install it with 'pip install pytz'."
            ) from e

        run_id = kwargs["run_id"]

        # get the run params from this run ID,
        # or raise an error if this run ID has no corresponding metadata in
        # self.run_map. `run_map` is a defaultdict, so plain indexing would
        # silently create an empty entry instead of failing; pop() does not
        # trigger the default factory and also frees the entry once the run
        # is finished.
        try:
            run_map_data = self.run_map.pop(run_id)
        except KeyError as e:
            raise KeyError(
                "This function has been called with a run_id"
                " that was never registered in on_llm_start()."
                " Restart and try running the LLM again"
            ) from e

        # mark the duration time between on_llm_start() and on_llm_end()
        time_from_start_to_end = time() - run_map_data["start_time"]

        # create inferences to log to Arthur
        inferences = []
        for i, generations in enumerate(response.generations):
            for generation in generations:
                inference = {
                    "partner_inference_id": str(uuid.uuid4()),
                    "inference_timestamp": datetime.now(tz=pytz.UTC),
                    self.input_attr: run_map_data["input_texts"][i],
                    self.output_attr: generation.text,
                }

                generation_info = generation.generation_info
                if generation_info is not None:
                    # add finish reason to the inference
                    # if generation info contains a finish reason and
                    # if the ArthurModel was registered to monitor finish_reason
                    if (
                        FINISH_REASON in generation_info
                        and FINISH_REASON in self.attr_names
                    ):
                        inference[FINISH_REASON] = generation_info[FINISH_REASON]

                    # add token likelihoods data to the inference if the
                    # ArthurModel was registered to monitor token likelihoods.
                    # Use .get(): generation_info may carry no "logprobs" key.
                    logprobs_data = generation_info.get("logprobs")
                    if (
                        logprobs_data is not None
                        and self.token_likelihood_attr is not None
                    ):
                        logprobs = logprobs_data["top_logprobs"]
                        # convert each log-probability into a likelihood
                        likelihoods = [
                            {k: np.exp(v) for k, v in token_logprobs.items()}
                            for token_logprobs in logprobs
                        ]
                        inference[self.token_likelihood_attr] = likelihoods

                # add token usage counts to the inference if the
                # ArthurModel was registered to monitor token usage
                if (
                    isinstance(response.llm_output, dict)
                    and TOKEN_USAGE in response.llm_output
                ):
                    token_usage = response.llm_output[TOKEN_USAGE]
                    if (
                        PROMPT_TOKENS in token_usage
                        and PROMPT_TOKENS in self.attr_names
                    ):
                        inference[PROMPT_TOKENS] = token_usage[PROMPT_TOKENS]
                    if (
                        COMPLETION_TOKENS in token_usage
                        and COMPLETION_TOKENS in self.attr_names
                    ):
                        inference[COMPLETION_TOKENS] = token_usage[COMPLETION_TOKENS]

                # add inference duration to the inference if the ArthurModel
                # was registered to monitor inference duration
                if DURATION in self.attr_names:
                    inference[DURATION] = time_from_start_to_end

                inferences.append(inference)

        # send inferences to arthur
        self.arthur_model.send_inferences(inferences)

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """On chain start, do nothing."""

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """On chain end, do nothing."""

    def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
        """Log nothing when LLM outputs an error; only drop the run's data.

        A failed run never reaches `on_llm_end`, so the prompts and start
        time saved by `on_llm_start` are discarded here to keep `run_map`
        from growing.
        """
        self.run_map.pop(kwargs.get("run_id"), None)

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """On new token, pass."""

    def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
        """Do nothing when LLM chain outputs an error."""

    def on_tool_start(
        self,
        serialized: Dict[str, Any],
        input_str: str,
        **kwargs: Any,
    ) -> None:
        """Do nothing when tool starts."""

    def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
        """Do nothing when agent takes a specific action."""

    def on_tool_end(
        self,
        output: Any,
        observation_prefix: Optional[str] = None,
        llm_prefix: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Do nothing when tool ends."""

    def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
        """Do nothing when tool outputs an error."""

    def on_text(self, text: str, **kwargs: Any) -> None:
        """Do nothing"""

    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
        """Do nothing"""
|