Update ArgillaCallbackHandler as of latest argilla release (#9043)

Hi @agola11, or whoever is reviewing this PR 😄 

## What's in this PR?

As of the latest Argilla release, we're changing and refactoring some
things to make some workflows easier. One of those is how everything is
pushed to Argilla: there's no longer any need to call `push_to_argilla`
on a `FeedbackDataset` once either `push_to_argilla` has been called for
the first time or `from_argilla` has been called; among others.

We also add some class variables so that their values are easy to update
in case we change them internally in the future, and to make the
`warnings.warn` messages lighter to read in the code.

P.S. Regarding the Twitter/X mention, feel free to do so at either
https://twitter.com/argilla_io or https://twitter.com/alvarobartt, or
both if applicable; otherwise, just use the first Twitter/X handle.
This commit is contained in:
Alvaro Bartolome 2023-08-10 19:59:46 +02:00 committed by GitHub
parent 8d351bfc20
commit 08a0741d82
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -2,6 +2,8 @@ import os
import warnings
from typing import Any, Dict, List, Optional, Union
from packaging.version import parse
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult
@ -51,6 +53,13 @@ class ArgillaCallbackHandler(BaseCallbackHandler):
"Argilla, no doubt about it."
"""
REPO_URL = "https://github.com/argilla-io/argilla"
ISSUES_URL = f"{REPO_URL}/issues"
BLOG_URL = "https://docs.argilla.io/en/latest/guides/llms/practical_guides/use_argilla_callback_in_langchain.html" # noqa: E501
DEFAULT_API_URL = "http://localhost:6900"
DEFAULT_API_KEY = "argilla.apikey"
def __init__(
self,
dataset_name: str,
@ -58,23 +67,22 @@ class ArgillaCallbackHandler(BaseCallbackHandler):
api_url: Optional[str] = None,
api_key: Optional[str] = None,
) -> None:
"""Initializes the `ArgillaCallbackHandler`.
f"""Initializes the `ArgillaCallbackHandler`.
Args:
dataset_name: name of the `FeedbackDataset` in Argilla. Note that it must
exist in advance. If you need help on how to create a `FeedbackDataset`
in Argilla, please visit
https://docs.argilla.io/en/latest/guides/llms/practical_guides/use_argilla_callback_in_langchain.html.
in Argilla, please visit {self.BLOG_URL}.
workspace_name: name of the workspace in Argilla where the specified
`FeedbackDataset` lives in. Defaults to `None`, which means that the
default workspace will be used.
api_url: URL of the Argilla Server that we want to use, and where the
`FeedbackDataset` lives in. Defaults to `None`, which means that either
`ARGILLA_API_URL` environment variable or the default
http://localhost:6900 will be used.
`ARGILLA_API_URL` environment variable or `{self.DEFAULT_API_URL}` will
be used.
api_key: API Key to connect to the Argilla Server. Defaults to `None`, which
means that either `ARGILLA_API_KEY` environment variable or the default
`argilla.apikey` will be used.
`{self.DEFAULT_API_KEY}` will be used.
Raises:
ImportError: if the `argilla` package is not installed.
@ -87,41 +95,47 @@ class ArgillaCallbackHandler(BaseCallbackHandler):
# Import Argilla (not via `import_argilla` to keep hints in IDEs)
try:
import argilla as rg # noqa: F401
self.ARGILLA_VERSION = rg.__version__
except ImportError:
raise ImportError(
"To use the Argilla callback manager you need to have the `argilla` "
"Python package installed. Please install it with `pip install argilla`"
)
# Check whether the Argilla version is compatible
if parse(self.ARGILLA_VERSION) < parse("1.8.0"):
raise ImportError(
f"The installed `argilla` version is {self.ARGILLA_VERSION} but "
"`ArgillaCallbackHandler` requires at least version 1.8.0. Please "
"upgrade `argilla` with `pip install --upgrade argilla`."
)
# Show a warning message if Argilla will assume the default values will be used
if api_url is None and os.getenv("ARGILLA_API_URL") is None:
warnings.warn(
(
"Since `api_url` is None, and the env var `ARGILLA_API_URL` is not"
" set, it will default to `http://localhost:6900`."
f" set, it will default to `{self.DEFAULT_API_URL}`."
),
)
if api_key is None and os.getenv("ARGILLA_API_KEY") is None:
warnings.warn(
(
"Since `api_key` is None, and the env var `ARGILLA_API_KEY` is not"
" set, it will default to `argilla.apikey`."
f" set, it will default to `{self.DEFAULT_API_KEY}`."
),
)
# Connect to Argilla with the provided credentials, if applicable
try:
rg.init(
api_key=api_key,
api_url=api_url,
)
rg.init(api_key=api_key, api_url=api_url)
except Exception as e:
raise ConnectionError(
f"Could not connect to Argilla with exception: '{e}'.\n"
"Please check your `api_key` and `api_url`, and make sure that "
"the Argilla server is up and running. If the problem persists "
"please report it to https://github.com/argilla-io/argilla/issues "
"with the label `langchain`."
f"please report it to {self.ISSUES_URL} as an `integration` issue."
) from e
# Set the Argilla variables
@ -130,46 +144,47 @@ class ArgillaCallbackHandler(BaseCallbackHandler):
# Retrieve the `FeedbackDataset` from Argilla (without existing records)
try:
extra_args = {}
if parse(self.ARGILLA_VERSION) < parse("1.14.0"):
warnings.warn(
f"You have Argilla {self.ARGILLA_VERSION}, but Argilla 1.14.0 or"
" higher is recommended.",
UserWarning,
)
extra_args = {"with_records": False}
self.dataset = rg.FeedbackDataset.from_argilla(
name=self.dataset_name,
workspace=self.workspace_name,
with_records=False,
**extra_args,
)
except Exception as e:
raise FileNotFoundError(
"`FeedbackDataset` retrieval from Argilla failed with exception:"
f" '{e}'.\nPlease check that the dataset with"
f" name={self.dataset_name} in the"
f"`FeedbackDataset` retrieval from Argilla failed with exception `{e}`."
f"\nPlease check that the dataset with name={self.dataset_name} in the"
f" workspace={self.workspace_name} exists in advance. If you need help"
" on how to create a `langchain`-compatible `FeedbackDataset` in"
" Argilla, please visit"
" https://docs.argilla.io/en/latest/guides/llms/practical_guides/use_argilla_callback_in_langchain.html." # noqa: E501
" If the problem persists please report it to"
" https://github.com/argilla-io/argilla/issues with the label"
" `langchain`."
f" Argilla, please visit {self.BLOG_URL}. If the problem persists"
f" please report it to {self.ISSUES_URL} as an `integration` issue."
) from e
supported_fields = ["prompt", "response"]
if supported_fields != [field.name for field in self.dataset.fields]:
raise ValueError(
f"`FeedbackDataset` with name={self.dataset_name} in the"
f" workspace={self.workspace_name} "
"had fields that are not supported yet for the `langchain` integration."
" Supported fields are: "
f"{supported_fields}, and the current `FeedbackDataset` fields are"
f" {[field.name for field in self.dataset.fields]}. "
"For more information on how to create a `langchain`-compatible"
" `FeedbackDataset` in Argilla, please visit"
" https://docs.argilla.io/en/latest/guides/llms/practical_guides/use_argilla_callback_in_langchain.html." # noqa: E501
f"`FeedbackDataset` with name={self.dataset_name} in the workspace="
f"{self.workspace_name} had fields that are not supported yet for the"
f"`langchain` integration. Supported fields are: {supported_fields},"
f" and the current `FeedbackDataset` fields are {[field.name for field in self.dataset.fields]}." # noqa: E501
" For more information on how to create a `langchain`-compatible"
f" `FeedbackDataset` in Argilla, please visit {self.BLOG_URL}."
)
self.prompts: Dict[str, List[str]] = {}
warnings.warn(
(
"The `ArgillaCallbackHandler` is currently in beta and is subject to "
"change based on updates to `langchain`. Please report any issues to "
"https://github.com/argilla-io/argilla/issues with the tag `langchain`."
"The `ArgillaCallbackHandler` is currently in beta and is subject to"
" change based on updates to `langchain`. Please report any issues to"
f" {self.ISSUES_URL} as an `integration` issue."
),
)
@ -205,12 +220,13 @@ class ArgillaCallbackHandler(BaseCallbackHandler):
]
)
# Push the records to Argilla
self.dataset.push_to_argilla()
# Pop current run from `self.runs`
self.prompts.pop(str(kwargs["run_id"]))
if parse(self.ARGILLA_VERSION) < parse("1.14.0"):
# Push the records to Argilla
self.dataset.push_to_argilla()
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
@ -278,15 +294,16 @@ class ArgillaCallbackHandler(BaseCallbackHandler):
]
)
# Push the records to Argilla
self.dataset.push_to_argilla()
# Pop current run from `self.runs`
if str(kwargs["parent_run_id"]) in self.prompts:
self.prompts.pop(str(kwargs["parent_run_id"]))
if str(kwargs["run_id"]) in self.prompts:
self.prompts.pop(str(kwargs["run_id"]))
if parse(self.ARGILLA_VERSION) < parse("1.14.0"):
# Push the records to Argilla
self.dataset.push_to_argilla()
def on_chain_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None: