diff --git a/libs/langchain/langchain/callbacks/argilla_callback.py b/libs/langchain/langchain/callbacks/argilla_callback.py index 1d550461fa..7ed9af15ba 100644 --- a/libs/langchain/langchain/callbacks/argilla_callback.py +++ b/libs/langchain/langchain/callbacks/argilla_callback.py @@ -2,6 +2,8 @@ import os import warnings from typing import Any, Dict, List, Optional, Union +from packaging.version import parse + from langchain.callbacks.base import BaseCallbackHandler from langchain.schema import AgentAction, AgentFinish, LLMResult @@ -51,6 +53,13 @@ class ArgillaCallbackHandler(BaseCallbackHandler): "Argilla, no doubt about it." """ + REPO_URL = "https://github.com/argilla-io/argilla" + ISSUES_URL = f"{REPO_URL}/issues" + BLOG_URL = "https://docs.argilla.io/en/latest/guides/llms/practical_guides/use_argilla_callback_in_langchain.html" # noqa: E501 + + DEFAULT_API_URL = "http://localhost:6900" + DEFAULT_API_KEY = "argilla.apikey" + def __init__( self, dataset_name: str, @@ -58,23 +67,22 @@ class ArgillaCallbackHandler(BaseCallbackHandler): api_url: Optional[str] = None, api_key: Optional[str] = None, ) -> None: - """Initializes the `ArgillaCallbackHandler`. + f"""Initializes the `ArgillaCallbackHandler`. Args: dataset_name: name of the `FeedbackDataset` in Argilla. Note that it must exist in advance. If you need help on how to create a `FeedbackDataset` - in Argilla, please visit - https://docs.argilla.io/en/latest/guides/llms/practical_guides/use_argilla_callback_in_langchain.html. + in Argilla, please visit {self.BLOG_URL}. workspace_name: name of the workspace in Argilla where the specified `FeedbackDataset` lives in. Defaults to `None`, which means that the default workspace will be used. api_url: URL of the Argilla Server that we want to use, and where the `FeedbackDataset` lives in. Defaults to `None`, which means that either - `ARGILLA_API_URL` environment variable or the default - http://localhost:6900 will be used. + `ARGILLA_API_URL` environment variable or `{self.DEFAULT_API_URL}` will + be used. api_key: API Key to connect to the Argilla Server. Defaults to `None`, which means that either `ARGILLA_API_KEY` environment variable or the default - `argilla.apikey` will be used. + `{self.DEFAULT_API_KEY}` will be used. Raises: ImportError: if the `argilla` package is not installed. @@ -87,41 +95,47 @@ class ArgillaCallbackHandler(BaseCallbackHandler): # Import Argilla (not via `import_argilla` to keep hints in IDEs) try: import argilla as rg # noqa: F401 + + self.ARGILLA_VERSION = rg.__version__ except ImportError: raise ImportError( "To use the Argilla callback manager you need to have the `argilla` " "Python package installed. Please install it with `pip install argilla`" ) + # Check whether the Argilla version is compatible + if parse(self.ARGILLA_VERSION) < parse("1.8.0"): + raise ImportError( + f"The installed `argilla` version is {self.ARGILLA_VERSION} but " + "`ArgillaCallbackHandler` requires at least version 1.8.0. Please " + "upgrade `argilla` with `pip install --upgrade argilla`." + ) + # Show a warning message if Argilla will assume the default values will be used if api_url is None and os.getenv("ARGILLA_API_URL") is None: warnings.warn( ( "Since `api_url` is None, and the env var `ARGILLA_API_URL` is not" - " set, it will default to `http://localhost:6900`." + f" set, it will default to `{self.DEFAULT_API_URL}`." ), ) if api_key is None and os.getenv("ARGILLA_API_KEY") is None: warnings.warn( ( "Since `api_key` is None, and the env var `ARGILLA_API_KEY` is not" - " set, it will default to `argilla.apikey`." + f" set, it will default to `{self.DEFAULT_API_KEY}`." ), ) # Connect to Argilla with the provided credentials, if applicable try: - rg.init( - api_key=api_key, - api_url=api_url, - ) + rg.init(api_key=api_key, api_url=api_url) except Exception as e: raise ConnectionError( f"Could not connect to Argilla with exception: '{e}'.\n" "Please check your `api_key` and `api_url`, and make sure that " "the Argilla server is up and running. If the problem persists " - "please report it to https://github.com/argilla-io/argilla/issues " - "with the label `langchain`." + f"please report it to {self.ISSUES_URL} as an `integration` issue." ) from e # Set the Argilla variables @@ -130,46 +144,47 @@ class ArgillaCallbackHandler(BaseCallbackHandler): # Retrieve the `FeedbackDataset` from Argilla (without existing records) try: + extra_args = {} + if parse(self.ARGILLA_VERSION) < parse("1.14.0"): + warnings.warn( + f"You have Argilla {self.ARGILLA_VERSION}, but Argilla 1.14.0 or" + " higher is recommended.", + UserWarning, + ) + extra_args = {"with_records": False} self.dataset = rg.FeedbackDataset.from_argilla( name=self.dataset_name, workspace=self.workspace_name, - with_records=False, + **extra_args, ) except Exception as e: raise FileNotFoundError( - "`FeedbackDataset` retrieval from Argilla failed with exception:" - f" '{e}'.\nPlease check that the dataset with" - f" name={self.dataset_name} in the" + f"`FeedbackDataset` retrieval from Argilla failed with exception `{e}`." + f"\nPlease check that the dataset with name={self.dataset_name} in the" f" workspace={self.workspace_name} exists in advance. If you need help" " on how to create a `langchain`-compatible `FeedbackDataset` in" - " Argilla, please visit" - " https://docs.argilla.io/en/latest/guides/llms/practical_guides/use_argilla_callback_in_langchain.html." # noqa: E501 - " If the problem persists please report it to" - " https://github.com/argilla-io/argilla/issues with the label" - " `langchain`." + f" Argilla, please visit {self.BLOG_URL}. If the problem persists" + f" please report it to {self.ISSUES_URL} as an `integration` issue." ) from e supported_fields = ["prompt", "response"] if supported_fields != [field.name for field in self.dataset.fields]: raise ValueError( - f"`FeedbackDataset` with name={self.dataset_name} in the" - f" workspace={self.workspace_name} " - "had fields that are not supported yet for the `langchain` integration." - " Supported fields are: " - f"{supported_fields}, and the current `FeedbackDataset` fields are" - f" {[field.name for field in self.dataset.fields]}. " - "For more information on how to create a `langchain`-compatible" - " `FeedbackDataset` in Argilla, please visit" - " https://docs.argilla.io/en/latest/guides/llms/practical_guides/use_argilla_callback_in_langchain.html." # noqa: E501 + f"`FeedbackDataset` with name={self.dataset_name} in the workspace=" + f"{self.workspace_name} had fields that are not supported yet for the" + f"`langchain` integration. Supported fields are: {supported_fields}," + f" and the current `FeedbackDataset` fields are {[field.name for field in self.dataset.fields]}." # noqa: E501 + " For more information on how to create a `langchain`-compatible" + f" `FeedbackDataset` in Argilla, please visit {self.BLOG_URL}." ) self.prompts: Dict[str, List[str]] = {} warnings.warn( ( - "The `ArgillaCallbackHandler` is currently in beta and is subject to " - "change based on updates to `langchain`. Please report any issues to " - "https://github.com/argilla-io/argilla/issues with the tag `langchain`." + "The `ArgillaCallbackHandler` is currently in beta and is subject to" + " change based on updates to `langchain`. Please report any issues to" + f" {self.ISSUES_URL} as an `integration` issue." ), ) @@ -205,12 +220,13 @@ class ArgillaCallbackHandler(BaseCallbackHandler): ] ) - # Push the records to Argilla - self.dataset.push_to_argilla() - # Pop current run from `self.runs` self.prompts.pop(str(kwargs["run_id"])) + if parse(self.ARGILLA_VERSION) < parse("1.14.0"): + # Push the records to Argilla + self.dataset.push_to_argilla() + def on_llm_error( self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any ) -> None: @@ -278,15 +294,16 @@ class ArgillaCallbackHandler(BaseCallbackHandler): ] ) - # Push the records to Argilla - self.dataset.push_to_argilla() - # Pop current run from `self.runs` if str(kwargs["parent_run_id"]) in self.prompts: self.prompts.pop(str(kwargs["parent_run_id"])) if str(kwargs["run_id"]) in self.prompts: self.prompts.pop(str(kwargs["run_id"])) + if parse(self.ARGILLA_VERSION) < parse("1.14.0"): + # Push the records to Argilla + self.dataset.push_to_argilla() + def on_chain_error( self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any ) -> None: