mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
Extend ArgillaCallbackHandler
support (#6153)
Hi again @agola11! 🤗
## What's in this PR?
After playing around with different chains we noticed that some chains
were using different `output_key`s and we were just handling some, so
we've extended the support to any output, either if it's a Python list
or a string.
Kudos to @dvsrepo for spotting this!
---------
Co-authored-by: Daniel Vila Suero <daniel@argilla.io>
This commit is contained in:
parent
a8cb9ee013
commit
e0dea577ee
@ -220,7 +220,11 @@ class ArgillaCallbackHandler(BaseCallbackHandler):
|
||||
def on_chain_start(
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
||||
) -> None:
|
||||
"""Do nothing when LLM chain starts."""
|
||||
"""If the key `input` is in `inputs`, then save it in `self.prompts` using
|
||||
either the `parent_run_id` or the `run_id` as the key. This is done so that
|
||||
we don't log the same input prompt twice, once when the LLM starts and once
|
||||
when the chain starts.
|
||||
"""
|
||||
if "input" in inputs:
|
||||
self.prompts.update(
|
||||
{
|
||||
@ -233,9 +237,20 @@ class ArgillaCallbackHandler(BaseCallbackHandler):
|
||||
)
|
||||
|
||||
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
||||
"""Do nothing when LLM chain ends."""
|
||||
prompts = self.prompts[str(kwargs["parent_run_id"] or kwargs["run_id"])]
|
||||
if "outputs" in outputs:
|
||||
"""If either the `parent_run_id` or the `run_id` is in `self.prompts`, then
|
||||
log the outputs to Argilla, and pop the run from `self.prompts`. The behavior
|
||||
differs if the output is a list or not.
|
||||
"""
|
||||
if not any(
|
||||
key in self.prompts
|
||||
for key in [str(kwargs["parent_run_id"]), str(kwargs["run_id"])]
|
||||
):
|
||||
return
|
||||
prompts = self.prompts.get(str(kwargs["parent_run_id"])) or self.prompts.get(
|
||||
str(kwargs["run_id"])
|
||||
)
|
||||
for chain_output_key, chain_output_val in outputs.items():
|
||||
if isinstance(chain_output_val, list):
|
||||
# Creates the records and adds them to the `FeedbackDataset`
|
||||
self.dataset.add_records(
|
||||
records=[
|
||||
@ -245,32 +260,32 @@ class ArgillaCallbackHandler(BaseCallbackHandler):
|
||||
"response": output["text"].strip(),
|
||||
},
|
||||
}
|
||||
for prompt, output in zip(prompts, outputs["outputs"])
|
||||
for prompt, output in zip(
|
||||
prompts, chain_output_val # type: ignore
|
||||
)
|
||||
]
|
||||
)
|
||||
elif "output" in outputs:
|
||||
else:
|
||||
# Creates the records and adds them to the `FeedbackDataset`
|
||||
self.dataset.add_records(
|
||||
records=[
|
||||
{
|
||||
"fields": {
|
||||
"prompt": " ".join(prompts),
|
||||
"response": outputs["output"].strip(),
|
||||
"prompt": " ".join(prompts), # type: ignore
|
||||
"response": chain_output_val.strip(),
|
||||
},
|
||||
}
|
||||
]
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"The `outputs` dictionary did not contain the expected keys `outputs` "
|
||||
"or `output`."
|
||||
)
|
||||
|
||||
# Push the records to Argilla
|
||||
self.dataset.push_to_argilla()
|
||||
|
||||
# Pop current run from `self.runs`
|
||||
self.prompts.pop(str(kwargs["parent_run_id"] or kwargs["run_id"]))
|
||||
if str(kwargs["parent_run_id"]) in self.prompts:
|
||||
self.prompts.pop(str(kwargs["parent_run_id"]))
|
||||
if str(kwargs["run_id"]) in self.prompts:
|
||||
self.prompts.pop(str(kwargs["run_id"]))
|
||||
|
||||
def on_chain_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
|
Loading…
Reference in New Issue
Block a user