mirror of
https://github.com/hwchase17/langchain
synced 2024-11-02 09:40:22 +00:00
791d59a2c8
Issue: we have several helper functions for importing third-party libraries, such as import_uptrain in [community.callbacks](https://api.python.langchain.com/en/latest/callbacks/langchain_community.callbacks.uptrain_callback.import_uptrain.html#langchain_community.callbacks.uptrain_callback.import_uptrain), and we also have core.utils.utils.guard_import, which serves exactly this purpose. The import_<package> functions behave inconsistently and should rather be private functions. Change: replaced these functions with the guard_import function. Related to #21133
252 lines
8.6 KiB
Python
252 lines
8.6 KiB
Python
import time
|
|
from typing import Any, Dict, List, Optional, cast
|
|
|
|
from langchain_core.agents import AgentAction, AgentFinish
|
|
from langchain_core.callbacks import BaseCallbackHandler
|
|
from langchain_core.messages import BaseMessage
|
|
from langchain_core.outputs import ChatGeneration, LLMResult
|
|
from langchain_core.utils import guard_import
|
|
|
|
|
|
def import_infino() -> Any:
    """Return a freshly constructed ``InfinoClient``.

    ``infinopy`` is an optional dependency, so it is loaded lazily through
    ``guard_import`` instead of at module import time (guard_import is
    expected to raise an ImportError with install hints when the package
    is missing).
    """
    infinopy_module = guard_import("infinopy")
    client = infinopy_module.InfinoClient()
    return client
|
|
|
|
|
|
def import_tiktoken() -> Any:
    """Load and return the ``tiktoken`` module.

    ``tiktoken`` is used to count tokens for OpenAI models. It is an optional
    dependency, so it is imported lazily via ``guard_import``.
    """
    tiktoken_module = guard_import("tiktoken")
    return tiktoken_module
|
|
|
|
|
|
def get_num_tokens(string: str, openai_model_name: str) -> int:
    """Calculate num tokens for OpenAI with tiktoken package.

    Official documentation:
    https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb

    Parameters:
        string (str): the text to tokenize.
        openai_model_name (str): the OpenAI model whose tokenizer is used.
            If tiktoken cannot map this name to an encoding, the
            ``cl100k_base`` encoding is used instead (and a warning is
            logged), so the result is then an approximation.

    Returns:
        int: the number of tokens ``string`` encodes to.
    """
    tiktoken = import_tiktoken()

    try:
        encoding = tiktoken.encoding_for_model(openai_model_name)
    except KeyError:
        # tiktoken raises KeyError for model names it cannot map to an
        # encoding (e.g. custom deployment names, or models newer than the
        # installed tiktoken). Fall back to a generic encoding instead of
        # failing the caller's token accounting.
        import logging

        logging.getLogger(__name__).warning(
            "Model %s not found in tiktoken; using cl100k_base encoding.",
            openai_model_name,
        )
        encoding = tiktoken.get_encoding("cl100k_base")

    return len(encoding.encode(string))
|
|
|
|
|
|
class InfinoCallbackHandler(BaseCallbackHandler):
    """Callback Handler that logs to Infino.

    Prompts and prompt responses are appended to the Infino log; latency,
    the error flag and token counts are appended as time-series values.
    Every payload carries ``model_id`` / ``model_version`` labels.

    Token usage is taken from ``LLMResult.llm_output["token_usage"]`` when
    the provider reports it. For ``ChatOpenAI`` models, prompt and
    completion token counts are additionally estimated with ``tiktoken``.
    """

    def __init__(
        self,
        model_id: Optional[str] = None,
        model_version: Optional[str] = None,
        verbose: bool = False,
    ) -> None:
        """Initialize the handler and its Infino client.

        Parameters:
            model_id (Optional[str]): label attached to every payload.
            model_version (Optional[str]): label attached to every payload.
            verbose (bool): if True, print each payload before sending it.
        """
        # Set Infino client
        self.client = import_infino()
        self.model_id = model_id
        self.model_version = model_version
        self.verbose = verbose
        # Whether the current run was started by a ChatOpenAI model; set by
        # on_chat_model_start and cleared by on_llm_start.
        self.is_chat_openai_model = False
        # Tokenizer model for tiktoken counting when the invocation params
        # do not provide a model name.
        self.chat_openai_model_name = "gpt-3.5-turbo"

    def _send_to_infino(
        self,
        key: str,
        value: Any,
        is_ts: bool = True,
    ) -> None:
        """Send the key-value to Infino.

        Parameters:
            key (str): the key to send to Infino.
            value (Any): the value to send to Infino.
            is_ts (bool): if True, the value is part of a time series, else it
                          is sent as a log message.
        """
        payload = {
            "date": int(time.time()),
            key: value,
            "labels": {
                "model_id": self.model_id,
                "model_version": self.model_version,
            },
        }
        if self.verbose:
            print(f"Tracking {key} with Infino: {payload}")  # noqa: T201

        # Append to Infino time series only if is_ts is True, otherwise
        # append to Infino log.
        if is_ts:
            self.client.append_ts(payload)
        else:
            self.client.append_log(payload)

    def on_llm_start(
        self,
        serialized: Dict[str, Any],
        prompts: List[str],
        **kwargs: Any,
    ) -> None:
        """Log the prompts to Infino, and set start time and error flag."""
        # This hook is used by non-chat (completion) models, which are never
        # ChatOpenAI. Clear any flag left over from an earlier chat run so
        # on_llm_end does not treat plain generations as ChatGenerations.
        self.is_chat_openai_model = False

        for prompt in prompts:
            self._send_to_infino("prompt", prompt, is_ts=False)

        # Set the error flag to indicate no error (this will get overridden
        # in on_llm_error if an error occurs).
        self.error = 0

        # Set the start time (so that we can calculate the request
        # duration in on_llm_end).
        self.start_time = time.time()

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Do nothing when a new token is generated."""
        pass

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Log the latency, error, token usage, and response to Infino."""
        # Calculate and track the request latency.
        self.end_time = time.time()
        duration = self.end_time - self.start_time
        self._send_to_infino("latency", duration)

        # Track success or error flag.
        self._send_to_infino("error", self.error)

        # Track prompt response.
        for generations in response.generations:
            for generation in generations:
                self._send_to_infino("prompt_response", generation.text, is_ts=False)

        # Track token usage reported by the provider (for non-chat models).
        # Not every provider reports usage, and some report a partial or
        # empty dict, so only send the counts that are actually present.
        llm_output = response.llm_output
        if isinstance(llm_output, dict):
            token_usage = llm_output.get("token_usage")
            if token_usage:
                for usage_key in ("prompt_tokens", "total_tokens", "completion_tokens"):
                    if usage_key in token_usage:
                        self._send_to_infino(usage_key, token_usage[usage_key])

        # Track completion token usage (for openai chat models). Sum over
        # every prompt's generations in the batch, mirroring how prompt
        # tokens are summed per message list in on_chat_model_start.
        if self.is_chat_openai_model:
            completion_tokens = 0
            for generations in response.generations:
                completion_text = " ".join(
                    cast(str, cast(ChatGeneration, generation).message.content)
                    for generation in generations
                )
                completion_tokens += get_num_tokens(
                    completion_text, openai_model_name=self.chat_openai_model_name
                )
            self._send_to_infino("completion_tokens", completion_tokens)

    def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
        """Set the error flag."""
        self.error = 1

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Do nothing when LLM chain starts."""
        pass

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """Do nothing when LLM chain ends."""
        pass

    def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
        """Do nothing when LLM chain errors (chain errors are not logged yet)."""
        pass

    def on_tool_start(
        self,
        serialized: Dict[str, Any],
        input_str: str,
        **kwargs: Any,
    ) -> None:
        """Do nothing when tool starts."""
        pass

    def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
        """Do nothing when agent takes a specific action."""
        pass

    def on_tool_end(
        self,
        output: str,
        observation_prefix: Optional[str] = None,
        llm_prefix: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Do nothing when tool ends."""
        pass

    def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
        """Do nothing when tool outputs an error."""
        pass

    def on_text(self, text: str, **kwargs: Any) -> None:
        """Do nothing."""
        pass

    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
        """Do nothing."""
        pass

    def on_chat_model_start(
        self,
        serialized: Dict[str, Any],
        messages: List[List[BaseMessage]],
        **kwargs: Any,
    ) -> None:
        """Run when LLM starts running.

        Logs the joined prompt to Infino and, for ChatOpenAI models whose
        invocation params carry a model name, the tiktoken-estimated prompt
        token count. Then resets the error flag and records the start time.
        """
        # Currently, for chat models, we only support input prompts for ChatOpenAI.
        # Check if this model is a ChatOpenAI model. The flag is recomputed on
        # every run so a handler reused across different models is not stale.
        model_ids = serialized.get("id") or []
        self.is_chat_openai_model = any(value == "ChatOpenAI" for value in model_ids)

        # Track prompt tokens for ChatOpenAI model.
        if self.is_chat_openai_model:
            invocation_params = kwargs.get("invocation_params")
            if invocation_params:
                model_name = invocation_params.get("model_name")
                if model_name:
                    self.chat_openai_model_name = model_name
                    prompt_tokens = 0
                    for message_list in messages:
                        message_string = " ".join(
                            cast(str, msg.content) for msg in message_list
                        )
                        num_tokens = get_num_tokens(
                            message_string,
                            openai_model_name=self.chat_openai_model_name,
                        )
                        prompt_tokens += num_tokens

                    self._send_to_infino("prompt_tokens", prompt_tokens)

        if self.verbose:
            # Built from adjacent literals so no source indentation leaks
            # into the printed text.
            print(  # noqa: T201
                "on_chat_model_start: "
                f"is_chat_openai_model={self.is_chat_openai_model}, "
                f"chat_openai_model_name={self.chat_openai_model_name}"
            )

        # Send the prompt to infino
        prompt = " ".join(
            cast(str, msg.content) for sublist in messages for msg in sublist
        )
        self._send_to_infino("prompt", prompt, is_ts=False)

        # Set the error flag to indicate no error (this will get overridden
        # in on_llm_error if an error occurs).
        self.error = 0

        # Set the start time (so that we can calculate the request
        # duration in on_llm_end).
        self.start_time = time.time()
|