mirror of
https://github.com/hwchase17/langchain
synced 2024-11-16 06:13:16 +00:00
43db4cd20e
This PR updates the on_tool_end handlers to return the raw output from the tool instead of casting it to a string. This is technically a breaking change, though it's impact is expected to be somewhat minimal. It will fix behavior in `astream_events` as well. Fixes the following issue #18760 raised by @eyurtsev --------- Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
350 lines
14 KiB
Python
350 lines
14 KiB
Python
import os
|
|
import warnings
|
|
from typing import Any, Dict, List, Optional, cast
|
|
|
|
from langchain_core.agents import AgentAction, AgentFinish
|
|
from langchain_core.callbacks import BaseCallbackHandler
|
|
from langchain_core.outputs import LLMResult
|
|
from packaging.version import parse
|
|
|
|
|
|
class ArgillaCallbackHandler(BaseCallbackHandler):
|
|
"""Callback Handler that logs into Argilla.
|
|
|
|
Args:
|
|
dataset_name: name of the `FeedbackDataset` in Argilla. Note that it must
|
|
exist in advance. If you need help on how to create a `FeedbackDataset` in
|
|
Argilla, please visit
|
|
https://docs.argilla.io/en/latest/tutorials_and_integrations/integrations/use_argilla_callback_in_langchain.html.
|
|
workspace_name: name of the workspace in Argilla where the specified
|
|
`FeedbackDataset` lives in. Defaults to `None`, which means that the
|
|
default workspace will be used.
|
|
api_url: URL of the Argilla Server that we want to use, and where the
|
|
`FeedbackDataset` lives in. Defaults to `None`, which means that either
|
|
`ARGILLA_API_URL` environment variable or the default will be used.
|
|
api_key: API Key to connect to the Argilla Server. Defaults to `None`, which
|
|
means that either `ARGILLA_API_KEY` environment variable or the default
|
|
will be used.
|
|
|
|
Raises:
|
|
ImportError: if the `argilla` package is not installed.
|
|
ConnectionError: if the connection to Argilla fails.
|
|
FileNotFoundError: if the `FeedbackDataset` retrieval from Argilla fails.
|
|
|
|
Examples:
|
|
>>> from langchain_community.llms import OpenAI
|
|
>>> from langchain_community.callbacks import ArgillaCallbackHandler
|
|
>>> argilla_callback = ArgillaCallbackHandler(
|
|
... dataset_name="my-dataset",
|
|
... workspace_name="my-workspace",
|
|
... api_url="http://localhost:6900",
|
|
... api_key="argilla.apikey",
|
|
... )
|
|
>>> llm = OpenAI(
|
|
... temperature=0,
|
|
... callbacks=[argilla_callback],
|
|
... verbose=True,
|
|
... openai_api_key="API_KEY_HERE",
|
|
... )
|
|
>>> llm.generate([
|
|
... "What is the best NLP-annotation tool out there? (no bias at all)",
|
|
... ])
|
|
"Argilla, no doubt about it."
|
|
"""
|
|
|
|
REPO_URL: str = "https://github.com/argilla-io/argilla"
|
|
ISSUES_URL: str = f"{REPO_URL}/issues"
|
|
BLOG_URL: str = "https://docs.argilla.io/en/latest/tutorials_and_integrations/integrations/use_argilla_callback_in_langchain.html" # noqa: E501
|
|
|
|
DEFAULT_API_URL: str = "http://localhost:6900"
|
|
|
|
def __init__(
|
|
self,
|
|
dataset_name: str,
|
|
workspace_name: Optional[str] = None,
|
|
api_url: Optional[str] = None,
|
|
api_key: Optional[str] = None,
|
|
) -> None:
|
|
"""Initializes the `ArgillaCallbackHandler`.
|
|
|
|
Args:
|
|
dataset_name: name of the `FeedbackDataset` in Argilla. Note that it must
|
|
exist in advance. If you need help on how to create a `FeedbackDataset`
|
|
in Argilla, please visit
|
|
https://docs.argilla.io/en/latest/tutorials_and_integrations/integrations/use_argilla_callback_in_langchain.html.
|
|
workspace_name: name of the workspace in Argilla where the specified
|
|
`FeedbackDataset` lives in. Defaults to `None`, which means that the
|
|
default workspace will be used.
|
|
api_url: URL of the Argilla Server that we want to use, and where the
|
|
`FeedbackDataset` lives in. Defaults to `None`, which means that either
|
|
`ARGILLA_API_URL` environment variable or the default will be used.
|
|
api_key: API Key to connect to the Argilla Server. Defaults to `None`, which
|
|
means that either `ARGILLA_API_KEY` environment variable or the default
|
|
will be used.
|
|
|
|
Raises:
|
|
ImportError: if the `argilla` package is not installed.
|
|
ConnectionError: if the connection to Argilla fails.
|
|
FileNotFoundError: if the `FeedbackDataset` retrieval from Argilla fails.
|
|
"""
|
|
|
|
super().__init__()
|
|
|
|
# Import Argilla (not via `import_argilla` to keep hints in IDEs)
|
|
try:
|
|
import argilla as rg # noqa: F401
|
|
|
|
self.ARGILLA_VERSION = rg.__version__
|
|
except ImportError:
|
|
raise ImportError(
|
|
"To use the Argilla callback manager you need to have the `argilla` "
|
|
"Python package installed. Please install it with `pip install argilla`"
|
|
)
|
|
|
|
# Check whether the Argilla version is compatible
|
|
if parse(self.ARGILLA_VERSION) < parse("1.8.0"):
|
|
raise ImportError(
|
|
f"The installed `argilla` version is {self.ARGILLA_VERSION} but "
|
|
"`ArgillaCallbackHandler` requires at least version 1.8.0. Please "
|
|
"upgrade `argilla` with `pip install --upgrade argilla`."
|
|
)
|
|
|
|
# Show a warning message if Argilla will assume the default values will be used
|
|
if api_url is None and os.getenv("ARGILLA_API_URL") is None:
|
|
warnings.warn(
|
|
(
|
|
"Since `api_url` is None, and the env var `ARGILLA_API_URL` is not"
|
|
f" set, it will default to `{self.DEFAULT_API_URL}`, which is the"
|
|
" default API URL in Argilla Quickstart."
|
|
),
|
|
)
|
|
api_url = self.DEFAULT_API_URL
|
|
|
|
if api_key is None and os.getenv("ARGILLA_API_KEY") is None:
|
|
self.DEFAULT_API_KEY = (
|
|
"admin.apikey"
|
|
if parse(self.ARGILLA_VERSION) < parse("1.11.0")
|
|
else "owner.apikey"
|
|
)
|
|
|
|
warnings.warn(
|
|
(
|
|
"Since `api_key` is None, and the env var `ARGILLA_API_KEY` is not"
|
|
f" set, it will default to `{self.DEFAULT_API_KEY}`, which is the"
|
|
" default API key in Argilla Quickstart."
|
|
),
|
|
)
|
|
api_key = self.DEFAULT_API_KEY
|
|
|
|
# Connect to Argilla with the provided credentials, if applicable
|
|
try:
|
|
rg.init(api_key=api_key, api_url=api_url)
|
|
except Exception as e:
|
|
raise ConnectionError(
|
|
f"Could not connect to Argilla with exception: '{e}'.\n"
|
|
"Please check your `api_key` and `api_url`, and make sure that "
|
|
"the Argilla server is up and running. If the problem persists "
|
|
f"please report it to {self.ISSUES_URL} as an `integration` issue."
|
|
) from e
|
|
|
|
# Set the Argilla variables
|
|
self.dataset_name = dataset_name
|
|
self.workspace_name = workspace_name or rg.get_workspace()
|
|
|
|
# Retrieve the `FeedbackDataset` from Argilla (without existing records)
|
|
try:
|
|
extra_args = {}
|
|
if parse(self.ARGILLA_VERSION) < parse("1.14.0"):
|
|
warnings.warn(
|
|
f"You have Argilla {self.ARGILLA_VERSION}, but Argilla 1.14.0 or"
|
|
" higher is recommended.",
|
|
UserWarning,
|
|
)
|
|
extra_args = {"with_records": False}
|
|
self.dataset = rg.FeedbackDataset.from_argilla(
|
|
name=self.dataset_name,
|
|
workspace=self.workspace_name,
|
|
**extra_args,
|
|
)
|
|
except Exception as e:
|
|
raise FileNotFoundError(
|
|
f"`FeedbackDataset` retrieval from Argilla failed with exception `{e}`."
|
|
f"\nPlease check that the dataset with name={self.dataset_name} in the"
|
|
f" workspace={self.workspace_name} exists in advance. If you need help"
|
|
" on how to create a `langchain`-compatible `FeedbackDataset` in"
|
|
f" Argilla, please visit {self.BLOG_URL}. If the problem persists"
|
|
f" please report it to {self.ISSUES_URL} as an `integration` issue."
|
|
) from e
|
|
|
|
supported_fields = ["prompt", "response"]
|
|
if supported_fields != [field.name for field in self.dataset.fields]:
|
|
raise ValueError(
|
|
f"`FeedbackDataset` with name={self.dataset_name} in the workspace="
|
|
f"{self.workspace_name} had fields that are not supported yet for the"
|
|
f"`langchain` integration. Supported fields are: {supported_fields},"
|
|
f" and the current `FeedbackDataset` fields are {[field.name for field in self.dataset.fields]}." # noqa: E501
|
|
" For more information on how to create a `langchain`-compatible"
|
|
f" `FeedbackDataset` in Argilla, please visit {self.BLOG_URL}."
|
|
)
|
|
|
|
self.prompts: Dict[str, List[str]] = {}
|
|
|
|
warnings.warn(
|
|
(
|
|
"The `ArgillaCallbackHandler` is currently in beta and is subject to"
|
|
" change based on updates to `langchain`. Please report any issues to"
|
|
f" {self.ISSUES_URL} as an `integration` issue."
|
|
),
|
|
)
|
|
|
|
def on_llm_start(
|
|
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
|
) -> None:
|
|
"""Save the prompts in memory when an LLM starts."""
|
|
self.prompts.update({str(kwargs["parent_run_id"] or kwargs["run_id"]): prompts})
|
|
|
|
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
|
"""Do nothing when a new token is generated."""
|
|
pass
|
|
|
|
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
|
"""Log records to Argilla when an LLM ends."""
|
|
# Do nothing if there's a parent_run_id, since we will log the records when
|
|
# the chain ends
|
|
if kwargs["parent_run_id"]:
|
|
return
|
|
|
|
# Creates the records and adds them to the `FeedbackDataset`
|
|
prompts = self.prompts[str(kwargs["run_id"])]
|
|
for prompt, generations in zip(prompts, response.generations):
|
|
self.dataset.add_records(
|
|
records=[
|
|
{
|
|
"fields": {
|
|
"prompt": prompt,
|
|
"response": generation.text.strip(),
|
|
},
|
|
}
|
|
for generation in generations
|
|
]
|
|
)
|
|
|
|
# Pop current run from `self.runs`
|
|
self.prompts.pop(str(kwargs["run_id"]))
|
|
|
|
if parse(self.ARGILLA_VERSION) < parse("1.14.0"):
|
|
# Push the records to Argilla
|
|
self.dataset.push_to_argilla()
|
|
|
|
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
|
|
"""Do nothing when LLM outputs an error."""
|
|
pass
|
|
|
|
def on_chain_start(
|
|
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
|
) -> None:
|
|
"""If the key `input` is in `inputs`, then save it in `self.prompts` using
|
|
either the `parent_run_id` or the `run_id` as the key. This is done so that
|
|
we don't log the same input prompt twice, once when the LLM starts and once
|
|
when the chain starts.
|
|
"""
|
|
if "input" in inputs:
|
|
self.prompts.update(
|
|
{
|
|
str(kwargs["parent_run_id"] or kwargs["run_id"]): (
|
|
inputs["input"]
|
|
if isinstance(inputs["input"], list)
|
|
else [inputs["input"]]
|
|
)
|
|
}
|
|
)
|
|
|
|
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
|
"""If either the `parent_run_id` or the `run_id` is in `self.prompts`, then
|
|
log the outputs to Argilla, and pop the run from `self.prompts`. The behavior
|
|
differs if the output is a list or not.
|
|
"""
|
|
if not any(
|
|
key in self.prompts
|
|
for key in [str(kwargs["parent_run_id"]), str(kwargs["run_id"])]
|
|
):
|
|
return
|
|
prompts: List = self.prompts.get(str(kwargs["parent_run_id"])) or cast(
|
|
List, self.prompts.get(str(kwargs["run_id"]), [])
|
|
)
|
|
for chain_output_key, chain_output_val in outputs.items():
|
|
if isinstance(chain_output_val, list):
|
|
# Creates the records and adds them to the `FeedbackDataset`
|
|
self.dataset.add_records(
|
|
records=[
|
|
{
|
|
"fields": {
|
|
"prompt": prompt,
|
|
"response": output["text"].strip(),
|
|
},
|
|
}
|
|
for prompt, output in zip(prompts, chain_output_val)
|
|
]
|
|
)
|
|
else:
|
|
# Creates the records and adds them to the `FeedbackDataset`
|
|
self.dataset.add_records(
|
|
records=[
|
|
{
|
|
"fields": {
|
|
"prompt": " ".join(prompts),
|
|
"response": chain_output_val.strip(),
|
|
},
|
|
}
|
|
]
|
|
)
|
|
|
|
# Pop current run from `self.runs`
|
|
if str(kwargs["parent_run_id"]) in self.prompts:
|
|
self.prompts.pop(str(kwargs["parent_run_id"]))
|
|
if str(kwargs["run_id"]) in self.prompts:
|
|
self.prompts.pop(str(kwargs["run_id"]))
|
|
|
|
if parse(self.ARGILLA_VERSION) < parse("1.14.0"):
|
|
# Push the records to Argilla
|
|
self.dataset.push_to_argilla()
|
|
|
|
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
|
|
"""Do nothing when LLM chain outputs an error."""
|
|
pass
|
|
|
|
def on_tool_start(
|
|
self,
|
|
serialized: Dict[str, Any],
|
|
input_str: str,
|
|
**kwargs: Any,
|
|
) -> None:
|
|
"""Do nothing when tool starts."""
|
|
pass
|
|
|
|
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
|
|
"""Do nothing when agent takes a specific action."""
|
|
pass
|
|
|
|
def on_tool_end(
|
|
self,
|
|
output: Any,
|
|
observation_prefix: Optional[str] = None,
|
|
llm_prefix: Optional[str] = None,
|
|
**kwargs: Any,
|
|
) -> None:
|
|
"""Do nothing when tool ends."""
|
|
pass
|
|
|
|
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
|
|
"""Do nothing when tool outputs an error."""
|
|
pass
|
|
|
|
def on_text(self, text: str, **kwargs: Any) -> None:
|
|
"""Do nothing"""
|
|
pass
|
|
|
|
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
|
|
"""Do nothing"""
|
|
pass
|