Extend `ArgillaCallbackHandler` support (#6153)

Hi again @agola11! 🤗

## What's in this PR?

After playing around with different chains we noticed that some chains
were using different `output_key`s and we were just handling some, so
we've extended the support to any output, either if it's a Python list
or a string.

Kudos to @dvsrepo for spotting this!

---------

Co-authored-by: Daniel Vila Suero <daniel@argilla.io>
master
Alvaro Bartolome 12 months ago committed by GitHub
parent a8cb9ee013
commit e0dea577ee
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -220,7 +220,11 @@ class ArgillaCallbackHandler(BaseCallbackHandler):
def on_chain_start( def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None: ) -> None:
"""Do nothing when LLM chain starts.""" """If the key `input` is in `inputs`, then save it in `self.prompts` using
either the `parent_run_id` or the `run_id` as the key. This is done so that
we don't log the same input prompt twice, once when the LLM starts and once
when the chain starts.
"""
if "input" in inputs: if "input" in inputs:
self.prompts.update( self.prompts.update(
{ {
@ -233,44 +237,55 @@ class ArgillaCallbackHandler(BaseCallbackHandler):
) )
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Do nothing when LLM chain ends.""" """If either the `parent_run_id` or the `run_id` is in `self.prompts`, then
prompts = self.prompts[str(kwargs["parent_run_id"] or kwargs["run_id"])] log the outputs to Argilla, and pop the run from `self.prompts`. The behavior
if "outputs" in outputs: differs if the output is a list or not.
# Creates the records and adds them to the `FeedbackDataset` """
self.dataset.add_records( if not any(
records=[ key in self.prompts
{ for key in [str(kwargs["parent_run_id"]), str(kwargs["run_id"])]
"fields": { ):
"prompt": prompt, return
"response": output["text"].strip(), prompts = self.prompts.get(str(kwargs["parent_run_id"])) or self.prompts.get(
}, str(kwargs["run_id"])
} )
for prompt, output in zip(prompts, outputs["outputs"]) for chain_output_key, chain_output_val in outputs.items():
] if isinstance(chain_output_val, list):
) # Creates the records and adds them to the `FeedbackDataset`
elif "output" in outputs: self.dataset.add_records(
# Creates the records and adds them to the `FeedbackDataset` records=[
self.dataset.add_records( {
records=[ "fields": {
{ "prompt": prompt,
"fields": { "response": output["text"].strip(),
"prompt": " ".join(prompts), },
"response": outputs["output"].strip(), }
}, for prompt, output in zip(
} prompts, chain_output_val # type: ignore
] )
) ]
else: )
raise ValueError( else:
"The `outputs` dictionary did not contain the expected keys `outputs` " # Creates the records and adds them to the `FeedbackDataset`
"or `output`." self.dataset.add_records(
) records=[
{
"fields": {
"prompt": " ".join(prompts), # type: ignore
"response": chain_output_val.strip(),
},
}
]
)
# Push the records to Argilla # Push the records to Argilla
self.dataset.push_to_argilla() self.dataset.push_to_argilla()
# Pop current run from `self.runs` # Pop current run from `self.runs`
self.prompts.pop(str(kwargs["parent_run_id"] or kwargs["run_id"])) if str(kwargs["parent_run_id"]) in self.prompts:
self.prompts.pop(str(kwargs["parent_run_id"]))
if str(kwargs["run_id"]) in self.prompts:
self.prompts.pop(str(kwargs["run_id"]))
def on_chain_error( def on_chain_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any

Loading…
Cancel
Save