mirror of
https://github.com/hwchase17/langchain
synced 2024-11-16 06:13:16 +00:00
Update ArgillaCallbackHandler
as of latest argilla
release (#9043)
Hi @agola11, or whoever is reviewing this PR 😄 ## What's in this PR? As of the latest Argilla release, we'll change and refactor some things to make some workflows easier, one of those is how everything's pushed to Argilla, so that now there's no need to call `push_to_argilla` over a `FeedbackDataset` when either `push_to_argilla` is called for the first time, or `from_argilla` is called; among others. We also add some class variables to make sure those are easy to update in case we update those internally in the future, also to make the `warnings.warn` message lighter from the code view. P.S. Regarding the Twitter/X mention feel free to do so at either https://twitter.com/argilla_io or https://twitter.com/alvarobartt, or both if applicable, otherwise, just the first Twitter/X handle.
This commit is contained in:
parent
8d351bfc20
commit
08a0741d82
@ -2,6 +2,8 @@ import os
|
||||
import warnings
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from packaging.version import parse
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
|
||||
@ -51,6 +53,13 @@ class ArgillaCallbackHandler(BaseCallbackHandler):
|
||||
"Argilla, no doubt about it."
|
||||
"""
|
||||
|
||||
REPO_URL = "https://github.com/argilla-io/argilla"
|
||||
ISSUES_URL = f"{REPO_URL}/issues"
|
||||
BLOG_URL = "https://docs.argilla.io/en/latest/guides/llms/practical_guides/use_argilla_callback_in_langchain.html" # noqa: E501
|
||||
|
||||
DEFAULT_API_URL = "http://localhost:6900"
|
||||
DEFAULT_API_KEY = "argilla.apikey"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset_name: str,
|
||||
@ -58,23 +67,22 @@ class ArgillaCallbackHandler(BaseCallbackHandler):
|
||||
api_url: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Initializes the `ArgillaCallbackHandler`.
|
||||
f"""Initializes the `ArgillaCallbackHandler`.
|
||||
|
||||
Args:
|
||||
dataset_name: name of the `FeedbackDataset` in Argilla. Note that it must
|
||||
exist in advance. If you need help on how to create a `FeedbackDataset`
|
||||
in Argilla, please visit
|
||||
https://docs.argilla.io/en/latest/guides/llms/practical_guides/use_argilla_callback_in_langchain.html.
|
||||
in Argilla, please visit {self.BLOG_URL}.
|
||||
workspace_name: name of the workspace in Argilla where the specified
|
||||
`FeedbackDataset` lives in. Defaults to `None`, which means that the
|
||||
default workspace will be used.
|
||||
api_url: URL of the Argilla Server that we want to use, and where the
|
||||
`FeedbackDataset` lives in. Defaults to `None`, which means that either
|
||||
`ARGILLA_API_URL` environment variable or the default
|
||||
http://localhost:6900 will be used.
|
||||
`ARGILLA_API_URL` environment variable or `{self.DEFAULT_API_URL}` will
|
||||
be used.
|
||||
api_key: API Key to connect to the Argilla Server. Defaults to `None`, which
|
||||
means that either `ARGILLA_API_KEY` environment variable or the default
|
||||
`argilla.apikey` will be used.
|
||||
`{self.DEFAULT_API_KEY}` will be used.
|
||||
|
||||
Raises:
|
||||
ImportError: if the `argilla` package is not installed.
|
||||
@ -87,41 +95,47 @@ class ArgillaCallbackHandler(BaseCallbackHandler):
|
||||
# Import Argilla (not via `import_argilla` to keep hints in IDEs)
|
||||
try:
|
||||
import argilla as rg # noqa: F401
|
||||
|
||||
self.ARGILLA_VERSION = rg.__version__
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"To use the Argilla callback manager you need to have the `argilla` "
|
||||
"Python package installed. Please install it with `pip install argilla`"
|
||||
)
|
||||
|
||||
# Check whether the Argilla version is compatible
|
||||
if parse(self.ARGILLA_VERSION) < parse("1.8.0"):
|
||||
raise ImportError(
|
||||
f"The installed `argilla` version is {self.ARGILLA_VERSION} but "
|
||||
"`ArgillaCallbackHandler` requires at least version 1.8.0. Please "
|
||||
"upgrade `argilla` with `pip install --upgrade argilla`."
|
||||
)
|
||||
|
||||
# Show a warning message if Argilla will assume the default values will be used
|
||||
if api_url is None and os.getenv("ARGILLA_API_URL") is None:
|
||||
warnings.warn(
|
||||
(
|
||||
"Since `api_url` is None, and the env var `ARGILLA_API_URL` is not"
|
||||
" set, it will default to `http://localhost:6900`."
|
||||
f" set, it will default to `{self.DEFAULT_API_URL}`."
|
||||
),
|
||||
)
|
||||
if api_key is None and os.getenv("ARGILLA_API_KEY") is None:
|
||||
warnings.warn(
|
||||
(
|
||||
"Since `api_key` is None, and the env var `ARGILLA_API_KEY` is not"
|
||||
" set, it will default to `argilla.apikey`."
|
||||
f" set, it will default to `{self.DEFAULT_API_KEY}`."
|
||||
),
|
||||
)
|
||||
|
||||
# Connect to Argilla with the provided credentials, if applicable
|
||||
try:
|
||||
rg.init(
|
||||
api_key=api_key,
|
||||
api_url=api_url,
|
||||
)
|
||||
rg.init(api_key=api_key, api_url=api_url)
|
||||
except Exception as e:
|
||||
raise ConnectionError(
|
||||
f"Could not connect to Argilla with exception: '{e}'.\n"
|
||||
"Please check your `api_key` and `api_url`, and make sure that "
|
||||
"the Argilla server is up and running. If the problem persists "
|
||||
"please report it to https://github.com/argilla-io/argilla/issues "
|
||||
"with the label `langchain`."
|
||||
f"please report it to {self.ISSUES_URL} as an `integration` issue."
|
||||
) from e
|
||||
|
||||
# Set the Argilla variables
|
||||
@ -130,46 +144,47 @@ class ArgillaCallbackHandler(BaseCallbackHandler):
|
||||
|
||||
# Retrieve the `FeedbackDataset` from Argilla (without existing records)
|
||||
try:
|
||||
extra_args = {}
|
||||
if parse(self.ARGILLA_VERSION) < parse("1.14.0"):
|
||||
warnings.warn(
|
||||
f"You have Argilla {self.ARGILLA_VERSION}, but Argilla 1.14.0 or"
|
||||
" higher is recommended.",
|
||||
UserWarning,
|
||||
)
|
||||
extra_args = {"with_records": False}
|
||||
self.dataset = rg.FeedbackDataset.from_argilla(
|
||||
name=self.dataset_name,
|
||||
workspace=self.workspace_name,
|
||||
with_records=False,
|
||||
**extra_args,
|
||||
)
|
||||
except Exception as e:
|
||||
raise FileNotFoundError(
|
||||
"`FeedbackDataset` retrieval from Argilla failed with exception:"
|
||||
f" '{e}'.\nPlease check that the dataset with"
|
||||
f" name={self.dataset_name} in the"
|
||||
f"`FeedbackDataset` retrieval from Argilla failed with exception `{e}`."
|
||||
f"\nPlease check that the dataset with name={self.dataset_name} in the"
|
||||
f" workspace={self.workspace_name} exists in advance. If you need help"
|
||||
" on how to create a `langchain`-compatible `FeedbackDataset` in"
|
||||
" Argilla, please visit"
|
||||
" https://docs.argilla.io/en/latest/guides/llms/practical_guides/use_argilla_callback_in_langchain.html." # noqa: E501
|
||||
" If the problem persists please report it to"
|
||||
" https://github.com/argilla-io/argilla/issues with the label"
|
||||
" `langchain`."
|
||||
f" Argilla, please visit {self.BLOG_URL}. If the problem persists"
|
||||
f" please report it to {self.ISSUES_URL} as an `integration` issue."
|
||||
) from e
|
||||
|
||||
supported_fields = ["prompt", "response"]
|
||||
if supported_fields != [field.name for field in self.dataset.fields]:
|
||||
raise ValueError(
|
||||
f"`FeedbackDataset` with name={self.dataset_name} in the"
|
||||
f" workspace={self.workspace_name} "
|
||||
"had fields that are not supported yet for the `langchain` integration."
|
||||
" Supported fields are: "
|
||||
f"{supported_fields}, and the current `FeedbackDataset` fields are"
|
||||
f" {[field.name for field in self.dataset.fields]}. "
|
||||
"For more information on how to create a `langchain`-compatible"
|
||||
" `FeedbackDataset` in Argilla, please visit"
|
||||
" https://docs.argilla.io/en/latest/guides/llms/practical_guides/use_argilla_callback_in_langchain.html." # noqa: E501
|
||||
f"`FeedbackDataset` with name={self.dataset_name} in the workspace="
|
||||
f"{self.workspace_name} had fields that are not supported yet for the"
|
||||
f"`langchain` integration. Supported fields are: {supported_fields},"
|
||||
f" and the current `FeedbackDataset` fields are {[field.name for field in self.dataset.fields]}." # noqa: E501
|
||||
" For more information on how to create a `langchain`-compatible"
|
||||
f" `FeedbackDataset` in Argilla, please visit {self.BLOG_URL}."
|
||||
)
|
||||
|
||||
self.prompts: Dict[str, List[str]] = {}
|
||||
|
||||
warnings.warn(
|
||||
(
|
||||
"The `ArgillaCallbackHandler` is currently in beta and is subject to "
|
||||
"change based on updates to `langchain`. Please report any issues to "
|
||||
"https://github.com/argilla-io/argilla/issues with the tag `langchain`."
|
||||
"The `ArgillaCallbackHandler` is currently in beta and is subject to"
|
||||
" change based on updates to `langchain`. Please report any issues to"
|
||||
f" {self.ISSUES_URL} as an `integration` issue."
|
||||
),
|
||||
)
|
||||
|
||||
@ -205,12 +220,13 @@ class ArgillaCallbackHandler(BaseCallbackHandler):
|
||||
]
|
||||
)
|
||||
|
||||
# Push the records to Argilla
|
||||
self.dataset.push_to_argilla()
|
||||
|
||||
# Pop current run from `self.runs`
|
||||
self.prompts.pop(str(kwargs["run_id"]))
|
||||
|
||||
if parse(self.ARGILLA_VERSION) < parse("1.14.0"):
|
||||
# Push the records to Argilla
|
||||
self.dataset.push_to_argilla()
|
||||
|
||||
def on_llm_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
@ -278,15 +294,16 @@ class ArgillaCallbackHandler(BaseCallbackHandler):
|
||||
]
|
||||
)
|
||||
|
||||
# Push the records to Argilla
|
||||
self.dataset.push_to_argilla()
|
||||
|
||||
# Pop current run from `self.runs`
|
||||
if str(kwargs["parent_run_id"]) in self.prompts:
|
||||
self.prompts.pop(str(kwargs["parent_run_id"]))
|
||||
if str(kwargs["run_id"]) in self.prompts:
|
||||
self.prompts.pop(str(kwargs["run_id"]))
|
||||
|
||||
if parse(self.ARGILLA_VERSION) < parse("1.14.0"):
|
||||
# Push the records to Argilla
|
||||
self.dataset.push_to_argilla()
|
||||
|
||||
def on_chain_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
|
Loading…
Reference in New Issue
Block a user