@ -2,6 +2,8 @@ import os
import warnings
from typing import Any , Dict , List , Optional , Union
from packaging . version import parse
from langchain . callbacks . base import BaseCallbackHandler
from langchain . schema import AgentAction , AgentFinish , LLMResult
@ -51,6 +53,13 @@ class ArgillaCallbackHandler(BaseCallbackHandler):
" Argilla, no doubt about it. "
"""
REPO_URL = " https://github.com/argilla-io/argilla "
ISSUES_URL = f " { REPO_URL } /issues "
BLOG_URL = " https://docs.argilla.io/en/latest/guides/llms/practical_guides/use_argilla_callback_in_langchain.html " # noqa: E501
DEFAULT_API_URL = " http://localhost:6900 "
DEFAULT_API_KEY = " argilla.apikey "
def __init__ (
self ,
dataset_name : str ,
@ -58,23 +67,22 @@ class ArgillaCallbackHandler(BaseCallbackHandler):
api_url : Optional [ str ] = None ,
api_key : Optional [ str ] = None ,
) - > None :
""" Initializes the `ArgillaCallbackHandler`.
f """ Initializes the `ArgillaCallbackHandler`.
Args :
dataset_name : name of the ` FeedbackDataset ` in Argilla . Note that it must
exist in advance . If you need help on how to create a ` FeedbackDataset `
in Argilla , please visit
https : / / docs . argilla . io / en / latest / guides / llms / practical_guides / use_argilla_callback_in_langchain . html .
in Argilla , please visit { self . BLOG_URL } .
workspace_name : name of the workspace in Argilla where the specified
` FeedbackDataset ` lives in . Defaults to ` None ` , which means that the
default workspace will be used .
api_url : URL of the Argilla Server that we want to use , and where the
` FeedbackDataset ` lives in . Defaults to ` None ` , which means that either
` ARGILLA_API_URL ` environment variable or the default
http: / / localhost : 6900 will be used .
` ARGILLA_API_URL ` environment variable or ` { self . DEFAULT_API_URL } ` will
be used .
api_key : API Key to connect to the Argilla Server . Defaults to ` None ` , which
means that either ` ARGILLA_API_KEY ` environment variable or the default
` argilla . apikey ` will be used .
` { self . DEFAULT_API_KEY } ` will be used .
Raises :
ImportError : if the ` argilla ` package is not installed .
@ -87,41 +95,47 @@ class ArgillaCallbackHandler(BaseCallbackHandler):
# Import Argilla (not via `import_argilla` to keep hints in IDEs)
try :
import argilla as rg # noqa: F401
self . ARGILLA_VERSION = rg . __version__
except ImportError :
raise ImportError (
" To use the Argilla callback manager you need to have the `argilla` "
" Python package installed. Please install it with `pip install argilla` "
)
# Check whether the Argilla version is compatible
if parse ( self . ARGILLA_VERSION ) < parse ( " 1.8.0 " ) :
raise ImportError (
f " The installed `argilla` version is { self . ARGILLA_VERSION } but "
" `ArgillaCallbackHandler` requires at least version 1.8.0. Please "
" upgrade `argilla` with `pip install --upgrade argilla`. "
)
# Show a warning message if Argilla will fall back to the default values
if api_url is None and os . getenv ( " ARGILLA_API_URL " ) is None :
warnings . warn (
(
" Since `api_url` is None, and the env var `ARGILLA_API_URL` is not "
" set, it will default to `http://localhost:6900`. "
f" set, it will default to ` { self . DEFAULT_API_URL } `."
) ,
)
if api_key is None and os . getenv ( " ARGILLA_API_KEY " ) is None :
warnings . warn (
(
" Since `api_key` is None, and the env var `ARGILLA_API_KEY` is not "
" set, it will default to `argilla.apikey `."
f" set, it will default to ` { self . DEFAULT_API_KEY } `."
) ,
)
# Connect to Argilla with the provided credentials, if applicable
try :
rg . init (
api_key = api_key ,
api_url = api_url ,
)
rg . init ( api_key = api_key , api_url = api_url )
except Exception as e :
raise ConnectionError (
f " Could not connect to Argilla with exception: ' { e } ' . \n "
" Please check your `api_key` and `api_url`, and make sure that "
" the Argilla server is up and running. If the problem persists "
" please report it to https://github.com/argilla-io/argilla/issues "
" with the label `langchain`. "
f " please report it to { self . ISSUES_URL } as an `integration` issue. "
) from e
# Set the Argilla variables
@ -130,46 +144,47 @@ class ArgillaCallbackHandler(BaseCallbackHandler):
# Retrieve the `FeedbackDataset` from Argilla (without existing records)
try :
extra_args = { }
if parse ( self . ARGILLA_VERSION ) < parse ( " 1.14.0 " ) :
warnings . warn (
f " You have Argilla { self . ARGILLA_VERSION } , but Argilla 1.14.0 or "
" higher is recommended. " ,
UserWarning ,
)
extra_args = { " with_records " : False }
self . dataset = rg . FeedbackDataset . from_argilla (
name = self . dataset_name ,
workspace = self . workspace_name ,
with_records = False ,
* * extra_args ,
)
except Exception as e :
raise FileNotFoundError (
" `FeedbackDataset` retrieval from Argilla failed with exception: "
f " ' { e } ' . \n Please check that the dataset with "
f " name= { self . dataset_name } in the "
f " `FeedbackDataset` retrieval from Argilla failed with exception ` { e } `. "
f " \n Please check that the dataset with name= { self . dataset_name } in the "
f " workspace= { self . workspace_name } exists in advance. If you need help "
" on how to create a `langchain`-compatible `FeedbackDataset` in "
" Argilla, please visit "
" https://docs.argilla.io/en/latest/guides/llms/practical_guides/use_argilla_callback_in_langchain.html. " # noqa: E501
" If the problem persists please report it to "
" https://github.com/argilla-io/argilla/issues with the label "
" `langchain`. "
f " Argilla, please visit { self . BLOG_URL } . If the problem persists "
f " please report it to { self . ISSUES_URL } as an `integration` issue. "
) from e
supported_fields = [ " prompt " , " response " ]
if supported_fields != [ field . name for field in self . dataset . fields ] :
raise ValueError (
f " `FeedbackDataset` with name= { self . dataset_name } in the "
f " workspace= { self . workspace_name } "
" had fields that are not supported yet for the `langchain` integration. "
" Supported fields are: "
f " { supported_fields } , and the current `FeedbackDataset` fields are "
f " { [ field . name for field in self . dataset . fields ] } . "
" For more information on how to create a `langchain`-compatible "
" `FeedbackDataset` in Argilla, please visit "
" https://docs.argilla.io/en/latest/guides/llms/practical_guides/use_argilla_callback_in_langchain.html. " # noqa: E501
f " `FeedbackDataset` with name= { self . dataset_name } in the workspace= "
f " { self . workspace_name } had fields that are not supported yet for the "
f " `langchain` integration. Supported fields are: { supported_fields } , "
f " and the current `FeedbackDataset` fields are { [ field . name for field in self . dataset . fields ] } . " # noqa: E501
" For more information on how to create a `langchain`-compatible "
f " `FeedbackDataset` in Argilla, please visit { self . BLOG_URL } . "
)
self . prompts : Dict [ str , List [ str ] ] = { }
warnings . warn (
(
" The `ArgillaCallbackHandler` is currently in beta and is subject to "
" change based on updates to `langchain`. Please report any issues to "
" https://github.com/argilla-io/argilla/issues with the tag `langchain` ."
" The `ArgillaCallbackHandler` is currently in beta and is subject to "
" change based on updates to `langchain`. Please report any issues to"
f" { self . ISSUES_URL } as an `integration` issue ."
) ,
)
@ -205,12 +220,13 @@ class ArgillaCallbackHandler(BaseCallbackHandler):
]
)
# Push the records to Argilla
self . dataset . push_to_argilla ( )
# Pop current run from `self.prompts`
self . prompts . pop ( str ( kwargs [ " run_id " ] ) )
if parse ( self . ARGILLA_VERSION ) < parse ( " 1.14.0 " ) :
# Push the records to Argilla
self . dataset . push_to_argilla ( )
def on_llm_error (
self , error : Union [ Exception , KeyboardInterrupt ] , * * kwargs : Any
) - > None :
@ -278,15 +294,16 @@ class ArgillaCallbackHandler(BaseCallbackHandler):
]
)
# Push the records to Argilla
self . dataset . push_to_argilla ( )
# Pop current run from `self.prompts`
if str ( kwargs [ " parent_run_id " ] ) in self . prompts :
self . prompts . pop ( str ( kwargs [ " parent_run_id " ] ) )
if str ( kwargs [ " run_id " ] ) in self . prompts :
self . prompts . pop ( str ( kwargs [ " run_id " ] ) )
if parse ( self . ARGILLA_VERSION ) < parse ( " 1.14.0 " ) :
# Push the records to Argilla
self . dataset . push_to_argilla ( )
def on_chain_error (
self , error : Union [ Exception , KeyboardInterrupt ] , * * kwargs : Any
) - > None :