@ -220,7 +220,11 @@ class ArgillaCallbackHandler(BaseCallbackHandler):
def on_chain_start (
self , serialized : Dict [ str , Any ] , inputs : Dict [ str , Any ] , * * kwargs : Any
) - > None :
""" Do nothing when LLM chain starts. """
""" If the key `input` is in `inputs`, then save it in `self.prompts` using
either the ` parent_run_id ` or the ` run_id ` as the key . This is done so that
we don ' t log the same input prompt twice, once when the LLM starts and once
when the chain starts .
"""
if " input " in inputs :
self . prompts . update (
{
@ -233,44 +237,55 @@ class ArgillaCallbackHandler(BaseCallbackHandler):
)
def on_chain_end ( self , outputs : Dict [ str , Any ] , * * kwargs : Any ) - > None :
""" Do nothing when LLM chain ends. """
prompts = self . prompts [ str ( kwargs [ " parent_run_id " ] or kwargs [ " run_id " ] ) ]
if " outputs " in outputs :
# Creates the records and adds them to the `FeedbackDataset`
self . dataset . add_records (
records = [
{
" fields " : {
" prompt " : prompt ,
" response " : output [ " text " ] . strip ( ) ,
} ,
}
for prompt , output in zip ( prompts , outputs [ " outputs " ] )
]
)
elif " output " in outputs :
# Creates the records and adds them to the `FeedbackDataset`
self . dataset . add_records (
records = [
{
" fields " : {
" prompt " : " " . join ( prompts ) ,
" response " : outputs [ " output " ] . strip ( ) ,
} ,
}
]
)
else :
raise ValueError (
" The `outputs` dictionary did not contain the expected keys `outputs` "
" or `output`. "
)
""" If either the `parent_run_id` or the `run_id` is in `self.prompts`, then
log the outputs to Argilla , and pop the run from ` self . prompts ` . The behavior
differs if the output is a list or not .
"""
if not any (
key in self . prompts
for key in [ str ( kwargs [ " parent_run_id " ] ) , str ( kwargs [ " run_id " ] ) ]
) :
return
prompts = self . prompts . get ( str ( kwargs [ " parent_run_id " ] ) ) or self . prompts . get (
str ( kwargs [ " run_id " ] )
)
for chain_output_key , chain_output_val in outputs . items ( ) :
if isinstance ( chain_output_val , list ) :
# Creates the records and adds them to the `FeedbackDataset`
self . dataset . add_records (
records = [
{
" fields " : {
" prompt " : prompt ,
" response " : output [ " text " ] . strip ( ) ,
} ,
}
for prompt , output in zip (
prompts , chain_output_val # type: ignore
)
]
)
else :
# Creates the records and adds them to the `FeedbackDataset`
self . dataset . add_records (
records = [
{
" fields " : {
" prompt " : " " . join ( prompts ) , # type: ignore
" response " : chain_output_val . strip ( ) ,
} ,
}
]
)
# Push the records to Argilla
self . dataset . push_to_argilla ( )
# Pop current run from `self.runs`
self . prompts . pop ( str ( kwargs [ " parent_run_id " ] or kwargs [ " run_id " ] ) )
if str ( kwargs [ " parent_run_id " ] ) in self . prompts :
self . prompts . pop ( str ( kwargs [ " parent_run_id " ] ) )
if str ( kwargs [ " run_id " ] ) in self . prompts :
self . prompts . pop ( str ( kwargs [ " run_id " ] ) )
def on_chain_error (
self , error : Union [ Exception , KeyboardInterrupt ] , * * kwargs : Any