@ -220,7 +220,11 @@ class ArgillaCallbackHandler(BaseCallbackHandler):
def on_chain_start (
def on_chain_start (
self , serialized : Dict [ str , Any ] , inputs : Dict [ str , Any ] , * * kwargs : Any
self , serialized : Dict [ str , Any ] , inputs : Dict [ str , Any ] , * * kwargs : Any
) - > None :
) - > None :
""" Do nothing when LLM chain starts. """
""" If the key `input` is in `inputs`, then save it in `self.prompts` using
either the ` parent_run_id ` or the ` run_id ` as the key . This is done so that
we don ' t log the same input prompt twice, once when the LLM starts and once
when the chain starts .
"""
if " input " in inputs :
if " input " in inputs :
self . prompts . update (
self . prompts . update (
{
{
@ -233,44 +237,55 @@ class ArgillaCallbackHandler(BaseCallbackHandler):
)
)
def on_chain_end ( self , outputs : Dict [ str , Any ] , * * kwargs : Any ) - > None :
def on_chain_end ( self , outputs : Dict [ str , Any ] , * * kwargs : Any ) - > None :
""" Do nothing when LLM chain ends. """
""" If either the `parent_run_id` or the `run_id` is in `self.prompts`, then
prompts = self . prompts [ str ( kwargs [ " parent_run_id " ] or kwargs [ " run_id " ] ) ]
log the outputs to Argilla , and pop the run from ` self . prompts ` . The behavior
if " outputs " in outputs :
differs if the output is a list or not .
# Creates the records and adds them to the `FeedbackDataset`
"""
self . dataset . add_records (
if not any (
records = [
key in self . prompts
{
for key in [ str ( kwargs [ " parent_run_id " ] ) , str ( kwargs [ " run_id " ] ) ]
" fields " : {
) :
" prompt " : prompt ,
return
" response " : output [ " text " ] . strip ( ) ,
prompts = self . prompts . get ( str ( kwargs [ " parent_run_id " ] ) ) or self . prompts . get (
} ,
str ( kwargs [ " run_id " ] )
}
)
for prompt , output in zip ( prompts , outputs [ " outputs " ] )
for chain_output_key , chain_output_val in outputs . items ( ) :
]
if isinstance ( chain_output_val , list ) :
)
# Creates the records and adds them to the `FeedbackDataset`
elif " output " in outputs :
self . dataset . add_records (
# Creates the records and adds them to the `FeedbackDataset`
records = [
self . dataset . add_records (
{
records = [
" fields " : {
{
" prompt " : prompt ,
" fields " : {
" response " : output [ " text " ] . strip ( ) ,
" prompt " : " " . join ( prompts ) ,
} ,
" response " : outputs [ " output " ] . strip ( ) ,
}
} ,
for prompt , output in zip (
}
prompts , chain_output_val # type: ignore
]
)
)
]
else :
)
raise ValueError (
else :
" The `outputs` dictionary did not contain the expected keys `outputs` "
# Creates the records and adds them to the `FeedbackDataset`
" or `output`. "
self . dataset . add_records (
)
records = [
{
" fields " : {
" prompt " : " " . join ( prompts ) , # type: ignore
" response " : chain_output_val . strip ( ) ,
} ,
}
]
)
# Push the records to Argilla
# Push the records to Argilla
self . dataset . push_to_argilla ( )
self . dataset . push_to_argilla ( )
# Pop current run from `self.runs`
# Pop current run from `self.runs`
self . prompts . pop ( str ( kwargs [ " parent_run_id " ] or kwargs [ " run_id " ] ) )
if str ( kwargs [ " parent_run_id " ] ) in self . prompts :
self . prompts . pop ( str ( kwargs [ " parent_run_id " ] ) )
if str ( kwargs [ " run_id " ] ) in self . prompts :
self . prompts . pop ( str ( kwargs [ " run_id " ] ) )
def on_chain_error (
def on_chain_error (
self , error : Union [ Exception , KeyboardInterrupt ] , * * kwargs : Any
self , error : Union [ Exception , KeyboardInterrupt ] , * * kwargs : Any