Extend ArgillaCallbackHandler support (#6153)

Hi again @agola11! 🤗

## What's in this PR?

After playing around with different chains, we noticed that some chains
use different `output_key`s and that we were only handling some of them, so
we've extended the support to any output, whether it's a Python list
or a string.

Kudos to @dvsrepo for spotting this!

---------

Co-authored-by: Daniel Vila Suero <daniel@argilla.io>
This commit is contained in:
Alvaro Bartolome 2023-06-18 20:18:33 +02:00 committed by GitHub
parent a8cb9ee013
commit e0dea577ee
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -220,7 +220,11 @@ class ArgillaCallbackHandler(BaseCallbackHandler):
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Do nothing when LLM chain starts."""
"""If the key `input` is in `inputs`, then save it in `self.prompts` using
either the `parent_run_id` or the `run_id` as the key. This is done so that
we don't log the same input prompt twice, once when the LLM starts and once
when the chain starts.
"""
if "input" in inputs:
self.prompts.update(
{
@ -233,9 +237,20 @@ class ArgillaCallbackHandler(BaseCallbackHandler):
)
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Do nothing when LLM chain ends."""
prompts = self.prompts[str(kwargs["parent_run_id"] or kwargs["run_id"])]
if "outputs" in outputs:
"""If either the `parent_run_id` or the `run_id` is in `self.prompts`, then
log the outputs to Argilla, and pop the run from `self.prompts`. The behavior
differs if the output is a list or not.
"""
if not any(
key in self.prompts
for key in [str(kwargs["parent_run_id"]), str(kwargs["run_id"])]
):
return
prompts = self.prompts.get(str(kwargs["parent_run_id"])) or self.prompts.get(
str(kwargs["run_id"])
)
for chain_output_key, chain_output_val in outputs.items():
if isinstance(chain_output_val, list):
# Creates the records and adds them to the `FeedbackDataset`
self.dataset.add_records(
records=[
@ -245,32 +260,32 @@ class ArgillaCallbackHandler(BaseCallbackHandler):
"response": output["text"].strip(),
},
}
for prompt, output in zip(prompts, outputs["outputs"])
for prompt, output in zip(
prompts, chain_output_val # type: ignore
)
]
)
elif "output" in outputs:
else:
# Creates the records and adds them to the `FeedbackDataset`
self.dataset.add_records(
records=[
{
"fields": {
"prompt": " ".join(prompts),
"response": outputs["output"].strip(),
"prompt": " ".join(prompts), # type: ignore
"response": chain_output_val.strip(),
},
}
]
)
else:
raise ValueError(
"The `outputs` dictionary did not contain the expected keys `outputs` "
"or `output`."
)
# Push the records to Argilla
self.dataset.push_to_argilla()
# Pop current run from `self.prompts`
self.prompts.pop(str(kwargs["parent_run_id"] or kwargs["run_id"]))
if str(kwargs["parent_run_id"]) in self.prompts:
self.prompts.pop(str(kwargs["parent_run_id"]))
if str(kwargs["run_id"]) in self.prompts:
self.prompts.pop(str(kwargs["run_id"]))
def on_chain_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any