@ -1,10 +1,9 @@
from __future__ import annotations
import logging
from typing import TYPE_CHECKING , Any , Dict, List , Optional, Union
from typing import TYPE_CHECKING , Any , Optional
from langchain . callbacks . base import BaseCallbackHandler
from langchain . schema import AgentAction , AgentFinish , Generation , LLMResult
from langchain . utils import get_from_env
if TYPE_CHECKING :
@ -91,99 +90,29 @@ class WhyLabsCallbackHandler(BaseCallbackHandler):
themes ( bool ) : Whether to enable theme analysis . Defaults to False .
"""
def __init__ ( self , logger : Logger ):
""" Initiate the rolling logger """
def __init__ ( self , logger : Logger , handler : Any ):
""" Initiate the rolling logger . """
super ( ) . __init__ ( )
self . logger = logger
diagnostic_logger . info (
" Initialized WhyLabs callback handler with configured whylogs Logger. "
)
def _profile_generations ( self , generations : List [ Generation ] ) - > None :
for gen in generations :
self . logger . log ( { " response " : gen . text } )
def on_llm_start (
self , serialized : Dict [ str , Any ] , prompts : List [ str ] , * * kwargs : Any
) - > None :
""" Pass the input prompts to the logger """
for prompt in prompts :
self . logger . log ( { " prompt " : prompt } )
def on_llm_end ( self , response : LLMResult , * * kwargs : Any ) - > None :
""" Pass the generated response to the logger. """
for generations in response . generations :
self . _profile_generations ( generations )
def on_llm_new_token ( self , token : str , * * kwargs : Any ) - > None :
""" Do nothing. """
pass
def on_llm_error (
self , error : Union [ Exception , KeyboardInterrupt ] , * * kwargs : Any
) - > None :
""" Do nothing. """
pass
def on_chain_start (
self , serialized : Dict [ str , Any ] , inputs : Dict [ str , Any ] , * * kwargs : Any
) - > None :
""" Do nothing. """
def on_chain_end ( self , outputs : Dict [ str , Any ] , * * kwargs : Any ) - > None :
""" Do nothing. """
def on_chain_error (
self , error : Union [ Exception , KeyboardInterrupt ] , * * kwargs : Any
) - > None :
""" Do nothing. """
pass
def on_tool_start (
self ,
serialized : Dict [ str , Any ] ,
input_str : str ,
* * kwargs : Any ,
) - > None :
""" Do nothing. """
def on_agent_action (
self , action : AgentAction , color : Optional [ str ] = None , * * kwargs : Any
) - > Any :
""" Do nothing. """
def on_tool_end (
self ,
output : str ,
color : Optional [ str ] = None ,
observation_prefix : Optional [ str ] = None ,
llm_prefix : Optional [ str ] = None ,
* * kwargs : Any ,
) - > None :
""" Do nothing. """
def on_tool_error (
self , error : Union [ Exception , KeyboardInterrupt ] , * * kwargs : Any
) - > None :
""" Do nothing. """
pass
def on_text ( self , text : str , * * kwargs : Any ) - > None :
""" Do nothing. """
def on_agent_finish (
self , finish : AgentFinish , color : Optional [ str ] = None , * * kwargs : Any
) - > None :
""" Run on agent end. """
pass
if hasattr ( handler , " init " ) :
handler . init ( self )
if hasattr ( handler , " _get_callbacks " ) :
self . _callbacks = handler . _get_callbacks ( )
else :
self . _callbacks = dict ( )
diagnostic_logger . warning ( " initialized handler without callbacks. " )
self . _logger = logger
def flush ( self ) - > None :
self . logger . _do_rollover ( )
diagnostic_logger . info ( " Flushing WhyLabs logger, writing profile... " )
""" Explicitly write current profile if using a rolling logger. """
if self . _logger and hasattr ( self . _logger , " _do_rollover " ) :
self . _logger . _do_rollover ( )
diagnostic_logger . info ( " Flushing WhyLabs logger, writing profile... " )
def close ( self ) - > None :
self . logger . close ( )
diagnostic_logger . info ( " Closing WhyLabs logger, see you next time! " )
""" Close any loggers to allow writing out of any profiles before exiting. """
if self . _logger and hasattr ( self . _logger , " close " ) :
self . _logger . close ( )
diagnostic_logger . info ( " Closing WhyLabs logger, see you next time! " )
def __enter__ ( self ) - > WhyLabsCallbackHandler :
return self
@ -203,7 +132,8 @@ class WhyLabsCallbackHandler(BaseCallbackHandler):
sentiment : bool = False ,
toxicity : bool = False ,
themes : bool = False ,
) - > Logger :
logger : Optional [ Logger ] = None ,
) - > WhyLabsCallbackHandler :
""" Instantiate whylogs Logger from params.
Args :
@ -224,31 +154,39 @@ class WhyLabsCallbackHandler(BaseCallbackHandler):
themes ( bool ) : If True will initialize a model to calculate
distance to configured themes . Defaults to None and will not gather this
metric .
logger ( Optional [ Logger ] ) : If specified will bind the configured logger as
the telemetry gathering agent . Defaults to LangKit schema with periodic
WhyLabs writer .
"""
# langkit library will import necessary whylogs libraries
import_langkit ( sentiment = sentiment , toxicity = toxicity , themes = themes )
import whylogs as why
from langkit . callback_handler import get_callback_instance
from whylogs . api . writer . whylabs import WhyLabsWriter
from whylogs . core . schema import DeclarativeSchema
from whylogs . experimental . core . metrics . udf_metric import generate_udf_schema
api_key = api_key or get_from_env ( " api_key " , " WHYLABS_API_KEY " )
org_id = org_id or get_from_env ( " org_id " , " WHYLABS_DEFAULT_ORG_ID " )
dataset_id = dataset_id or get_from_env (
" dataset_id " , " WHYLABS_DEFAULT_DATASET_ID "
)
whylabs_writer = WhyLabsWriter (
api_key = api_key , org_id = org_id , dataset_id = dataset_id
)
langkit_schema = DeclarativeSchema ( generate_udf_schema ( ) )
whylabs_logger = why . logger (
mode = " rolling " , interval = 5 , when = " M " , schema = langkit_schema
)
whylabs_logger . append_writer ( writer = whylabs_writer )
from whylogs . experimental . core . udf_schema import udf_schema
if logger is None :
api_key = api_key or get_from_env ( " api_key " , " WHYLABS_API_KEY " )
org_id = org_id or get_from_env ( " org_id " , " WHYLABS_DEFAULT_ORG_ID " )
dataset_id = dataset_id or get_from_env (
" dataset_id " , " WHYLABS_DEFAULT_DATASET_ID "
)
whylabs_writer = WhyLabsWriter (
api_key = api_key , org_id = org_id , dataset_id = dataset_id
)
whylabs_logger = why . logger (
mode = " rolling " , interval = 5 , when = " M " , schema = udf_schema ( )
)
whylabs_logger . append_writer ( writer = whylabs_writer )
else :
diagnostic_logger . info ( " Using passed in whylogs logger {logger} " )
whylabs_logger = logger
callback_handler_cls = get_callback_instance ( logger = whylabs_logger , impl = cls )
diagnostic_logger . info (
" Started whylogs Logger with WhyLabsWriter and initialized LangKit. 📝 "
)
return cls ( whylabs_logger )
return callback_handler_cls