@@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult
from langchain.schema.messages import BaseMessage

def import_infino() -> Any:
@@ -18,6 +19,32 @@ def import_infino() -> Any:
    return InfinoClient()


def import_tiktoken() -> Any:
    """Import tiktoken for counting tokens for OpenAI models."""
    try:
        import tiktoken
    except ImportError:
        raise ImportError(
            "To use the ChatOpenAI model with Infino callback manager, you need to "
            "have the `tiktoken` python package installed. "
            "Please install it with `pip install tiktoken`"
        )
    return tiktoken


def get_num_tokens(string: str, openai_model_name: str) -> int:
    """Calculate num tokens for OpenAI with tiktoken package.

    Official documentation: https://github.com/openai/openai-cookbook/blob/main
    /examples/How_to_count_tokens_with_tiktoken.ipynb
    """
    tiktoken = import_tiktoken()

    encoding = tiktoken.encoding_for_model(openai_model_name)
    num_tokens = len(encoding.encode(string))
    return num_tokens
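
# A quick sanity check (a sketch, not part of the diff): with `tiktoken`
# installed, the call below should return 4, since "Hello, world!" encodes
# to four tokens under the cl100k_base encoding used by gpt-3.5-turbo.
#
#   >>> get_num_tokens("Hello, world!", openai_model_name="gpt-3.5-turbo")
#   4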


class InfinoCallbackHandler(BaseCallbackHandler):
    """Callback Handler that logs to Infino."""
@@ -32,6 +59,8 @@ class InfinoCallbackHandler(BaseCallbackHandler):
        self.model_id = model_id
        self.model_version = model_version
        self.verbose = verbose
        self.is_chat_openai_model = False
        self.chat_openai_model_name = "gpt-3.5-turbo"

    def _send_to_infino(
        self,
@@ -97,7 +126,12 @@ class InfinoCallbackHandler(BaseCallbackHandler):
        # Track success or error flag.
        self._send_to_infino("error", self.error)

        # Track prompt response.
        for generations in response.generations:
            for generation in generations:
                self._send_to_infino("prompt_response", generation.text, is_ts=False)

        # Track token usage (for non-chat models).
        if (response.llm_output is not None) and isinstance(response.llm_output, Dict):
            token_usage = response.llm_output["token_usage"]
            if token_usage is not None:
@@ -108,10 +142,16 @@ self._send_to_infino("total_tokens", total_tokens)
self . _send_to_infino ( " total_tokens " , total_tokens )
self . _send_to_infino ( " total_tokens " , total_tokens )
self . _send_to_infino ( " completion_tokens " , completion_tokens )
self . _send_to_infino ( " completion_tokens " , completion_tokens )

        # Track completion token usage (for openai chat models).
        if self.is_chat_openai_model:
            # Concatenate the message content across all generations.
            messages = " ".join(
                generation.message.content  # type: ignore[attr-defined]
                for generations in response.generations
                for generation in generations
            )
            completion_tokens = get_num_tokens(
                messages, openai_model_name=self.chat_openai_model_name
            )
            self._send_to_infino("completion_tokens", completion_tokens)

    def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
        """Set the error flag."""
@@ -165,3 +205,57 @@ class InfinoCallbackHandler(BaseCallbackHandler):
    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
        """Do nothing."""
        pass

    def on_chat_model_start(
        self,
        serialized: Dict[str, Any],
        messages: List[List[BaseMessage]],
        **kwargs: Any,
    ) -> None:
        """Run when LLM starts running."""

        # Currently, for chat models, we only support input prompts for ChatOpenAI.
        # Check if this model is a ChatOpenAI model.
        values = serialized.get("id")
        if values:
            for value in values:
                if value == "ChatOpenAI":
                    self.is_chat_openai_model = True
                    break
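
        # Note (an assumption about langchain's serialization format):
        # serialized["id"] is a path-like list such as
        # ["langchain", "chat_models", "openai", "ChatOpenAI"].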

        # Track prompt tokens for ChatOpenAI model.
        if self.is_chat_openai_model:
            invocation_params = kwargs.get("invocation_params")
            if invocation_params:
                model_name = invocation_params.get("model_name")
                if model_name:
                    self.chat_openai_model_name = model_name
                    prompt_tokens = 0
                    for message_list in messages:
                        message_string = " ".join(msg.content for msg in message_list)
                        num_tokens = get_num_tokens(
                            message_string,
                            openai_model_name=self.chat_openai_model_name,
                        )
                        prompt_tokens += num_tokens

                    self._send_to_infino("prompt_tokens", prompt_tokens)

        if self.verbose:
            print(
                f"on_chat_model_start: is_chat_openai_model="
                f"{self.is_chat_openai_model}, "
                f"chat_openai_model_name={self.chat_openai_model_name}"
            )

        # Send the prompt to infino
        prompt = " ".join(msg.content for sublist in messages for msg in sublist)
        self._send_to_infino("prompt", prompt, is_ts=False)

        # Set the error flag to indicate no error (this will get overridden
        # in on_llm_error if an error occurs).
        self.error = 0

        # Set the start time (so that we can calculate the request
        # duration in on_llm_end).
        self.start_time = time.time()
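

# Usage sketch (not part of this diff; a minimal example, assuming `infinopy`
# and `tiktoken` are installed, an Infino server is reachable, and
# OPENAI_API_KEY is set):
#
#   from langchain.chat_models import ChatOpenAI
#
#   handler = InfinoCallbackHandler(model_id="test-model", model_version="0.1")
#   llm = ChatOpenAI(callbacks=[handler])
#   llm.predict("Tell me a joke.")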