Extend `ArgillaCallbackHandler` support (#6153)

Hi again @agola11! 🤗

## What's in this PR?

After playing around with different chains we noticed that some chains
were using different `output_key`s and we were just handling some, so
we've extended the support to any output, either if it's a Python list
or a string.

Kudos to @dvsrepo for spotting this!

---------

Co-authored-by: Daniel Vila Suero <daniel@argilla.io>
master
Alvaro Bartolome 11 months ago committed by GitHub
parent a8cb9ee013
commit e0dea577ee
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -220,7 +220,11 @@ class ArgillaCallbackHandler(BaseCallbackHandler):
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Do nothing when LLM chain starts."""
"""If the key `input` is in `inputs`, then save it in `self.prompts` using
either the `parent_run_id` or the `run_id` as the key. This is done so that
we don't log the same input prompt twice, once when the LLM starts and once
when the chain starts.
"""
if "input" in inputs:
self.prompts.update(
{
@ -233,44 +237,55 @@ class ArgillaCallbackHandler(BaseCallbackHandler):
)
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Do nothing when LLM chain ends."""
prompts = self.prompts[str(kwargs["parent_run_id"] or kwargs["run_id"])]
if "outputs" in outputs:
# Creates the records and adds them to the `FeedbackDataset`
self.dataset.add_records(
records=[
{
"fields": {
"prompt": prompt,
"response": output["text"].strip(),
},
}
for prompt, output in zip(prompts, outputs["outputs"])
]
)
elif "output" in outputs:
# Creates the records and adds them to the `FeedbackDataset`
self.dataset.add_records(
records=[
{
"fields": {
"prompt": " ".join(prompts),
"response": outputs["output"].strip(),
},
}
]
)
else:
raise ValueError(
"The `outputs` dictionary did not contain the expected keys `outputs` "
"or `output`."
)
"""If either the `parent_run_id` or the `run_id` is in `self.prompts`, then
log the outputs to Argilla, and pop the run from `self.prompts`. The behavior
differs if the output is a list or not.
"""
if not any(
key in self.prompts
for key in [str(kwargs["parent_run_id"]), str(kwargs["run_id"])]
):
return
prompts = self.prompts.get(str(kwargs["parent_run_id"])) or self.prompts.get(
str(kwargs["run_id"])
)
for chain_output_key, chain_output_val in outputs.items():
if isinstance(chain_output_val, list):
# Creates the records and adds them to the `FeedbackDataset`
self.dataset.add_records(
records=[
{
"fields": {
"prompt": prompt,
"response": output["text"].strip(),
},
}
for prompt, output in zip(
prompts, chain_output_val # type: ignore
)
]
)
else:
# Creates the records and adds them to the `FeedbackDataset`
self.dataset.add_records(
records=[
{
"fields": {
"prompt": " ".join(prompts), # type: ignore
"response": chain_output_val.strip(),
},
}
]
)
# Push the records to Argilla
self.dataset.push_to_argilla()
# Pop current run from `self.runs`
self.prompts.pop(str(kwargs["parent_run_id"] or kwargs["run_id"]))
if str(kwargs["parent_run_id"]) in self.prompts:
self.prompts.pop(str(kwargs["parent_run_id"]))
if str(kwargs["run_id"]) in self.prompts:
self.prompts.pop(str(kwargs["run_id"]))
def on_chain_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any

Loading…
Cancel
Save