diff --git a/langchain/callbacks/argilla_callback.py b/langchain/callbacks/argilla_callback.py index 8c866d33..1d550461 100644 --- a/langchain/callbacks/argilla_callback.py +++ b/langchain/callbacks/argilla_callback.py @@ -220,7 +220,11 @@ class ArgillaCallbackHandler(BaseCallbackHandler): def on_chain_start( self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any ) -> None: - """Do nothing when LLM chain starts.""" + """If the key `input` is in `inputs`, then save it in `self.prompts` using + either the `parent_run_id` or the `run_id` as the key. This is done so that + we don't log the same input prompt twice, once when the LLM starts and once + when the chain starts. + """ if "input" in inputs: self.prompts.update( { @@ -233,44 +237,55 @@ class ArgillaCallbackHandler(BaseCallbackHandler): ) def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: - """Do nothing when LLM chain ends.""" - prompts = self.prompts[str(kwargs["parent_run_id"] or kwargs["run_id"])] - if "outputs" in outputs: - # Creates the records and adds them to the `FeedbackDataset` - self.dataset.add_records( - records=[ - { - "fields": { - "prompt": prompt, - "response": output["text"].strip(), - }, - } - for prompt, output in zip(prompts, outputs["outputs"]) - ] - ) - elif "output" in outputs: - # Creates the records and adds them to the `FeedbackDataset` - self.dataset.add_records( - records=[ - { - "fields": { - "prompt": " ".join(prompts), - "response": outputs["output"].strip(), - }, - } - ] - ) - else: - raise ValueError( - "The `outputs` dictionary did not contain the expected keys `outputs` " - "or `output`." - ) + """If either the `parent_run_id` or the `run_id` is in `self.prompts`, then + log the outputs to Argilla, and pop the run from `self.prompts`. The behavior + differs if the output is a list or not. + """ + if not any( + key in self.prompts + for key in [str(kwargs["parent_run_id"]), str(kwargs["run_id"])] + ): + return + prompts = self.prompts.get(str(kwargs["parent_run_id"])) or self.prompts.get( + str(kwargs["run_id"]) + ) + for chain_output_key, chain_output_val in outputs.items(): + if isinstance(chain_output_val, list): + # Creates the records and adds them to the `FeedbackDataset` + self.dataset.add_records( + records=[ + { + "fields": { + "prompt": prompt, + "response": output["text"].strip(), + }, + } + for prompt, output in zip( + prompts, chain_output_val # type: ignore + ) + ] + ) + else: + # Creates the records and adds them to the `FeedbackDataset` + self.dataset.add_records( + records=[ + { + "fields": { + "prompt": " ".join(prompts), # type: ignore + "response": chain_output_val.strip(), + }, + } + ] + ) # Push the records to Argilla self.dataset.push_to_argilla() # Pop current run from `self.runs` - self.prompts.pop(str(kwargs["parent_run_id"] or kwargs["run_id"])) + if str(kwargs["parent_run_id"]) in self.prompts: + self.prompts.pop(str(kwargs["parent_run_id"])) + if str(kwargs["run_id"]) in self.prompts: + self.prompts.pop(str(kwargs["run_id"])) def on_chain_error( self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any