Dev2049/add argilla callback (#5621)

Co-authored-by: Alvaro Bartolome <alvarobartt@gmail.com>
Co-authored-by: Daniel Vila Suero <daniel@argilla.io>
Co-authored-by: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com>
Co-authored-by: Tom Aarsen <Cubiegamedev@gmail.com>
searx_updates
Davis Chase 12 months ago committed by GitHub
parent 71a7c16ee0
commit d784401215
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,29 @@
# Argilla
![Argilla - Open-source data platform for LLMs](https://argilla.io/og.png)
>[Argilla](https://argilla.io/) is an open-source data curation platform for LLMs.
> Using Argilla, everyone can build robust language models through faster data curation
> using both human and machine feedback. We provide support for each step in the MLOps cycle,
> from data labeling to model monitoring.
## Installation and Setup
First, you'll need to install the `argilla` Python package as follows:
```bash
pip install argilla --upgrade
```
If you already have an Argilla Server running, then you're good to go; but if
you don't, follow the next steps to install it.
If you don't you can refer to [Argilla - 🚀 Quickstart](https://docs.argilla.io/en/latest/getting_started/quickstart.html#Running-Argilla-Quickstart) to deploy Argilla either on HuggingFace Spaces, locally, or on a server.
## Tracking
See a [usage example of `ArgillaCallbackHandler`](../modules/callbacks/examples/examples/argilla.ipynb).
```python
from langchain.callbacks import ArgillaCallbackHandler
```

@ -0,0 +1,423 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Argilla\n",
"\n",
"![Argilla - Open-source data platform for LLMs](https://argilla.io/og.png)\n",
"\n",
">[Argilla](https://argilla.io/) is an open-source data curation platform for LLMs.\n",
"> Using Argilla, everyone can build robust language models through faster data curation \n",
"> using both human and machine feedback. We provide support for each step in the MLOps cycle, \n",
"> from data labeling to model monitoring.\n",
"\n",
"<a target=\"_blank\" href=\"https://colab.research.google.com/github/hwchase17/langchain/blob/master/docs/modules/callbacks/examples/argilla.ipynb\">\n",
" <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
"</a>"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"In this guide we will demonstrate how to track the inputs and reponses of your LLM to generate a dataset in Argilla, using the `ArgillaCallbackHandler`.\n",
"\n",
"It's useful to keep track of the inputs and outputs of your LLMs to generate datasets for future fine-tuning. This is especially useful when you're using a LLM to generate data for a specific task, such as question answering, summarization, or translation."
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"tags": []
},
"source": [
"## Installation and Setup"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!pip install argilla --upgrade\n",
"!pip install openai"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Getting API Credentials\n",
"\n",
"To get the Argilla API credentials, follow the next steps:\n",
"\n",
"1. Go to your Argilla UI.\n",
"2. Click on your profile picture and go to \"My settings\".\n",
"3. Then copy the API Key.\n",
"\n",
"In Argilla the API URL will be the same as the URL of your Argilla UI.\n",
"\n",
"To get the OpenAI API credentials, please visit https://platform.openai.com/account/api-keys"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"os.environ[\"ARGILLA_API_URL\"] = \"...\"\n",
"os.environ[\"ARGILLA_API_KEY\"] = \"...\"\n",
"\n",
"os.environ[\"OPENAI_API_KEY\"] = \"...\""
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Setup Argilla\n",
"\n",
"To use the `ArgillaCallbackHandler` we will need to create a new `FeedbackDataset` in Argilla to keep track of your LLM experiments. To do so, please use the following code:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import argilla as rg"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"from packaging.version import parse as parse_version\n",
"\n",
"if parse_version(rg.__version__) < parse_version(\"1.8.0\"):\n",
" raise RuntimeError(\n",
" \"`FeedbackDataset` is only available in Argilla v1.8.0 or higher, please \"\n",
" \"upgrade `argilla` as `pip install argilla --upgrade`.\"\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dataset = rg.FeedbackDataset(\n",
" fields=[\n",
" rg.TextField(name=\"prompt\"),\n",
" rg.TextField(name=\"response\"),\n",
" ],\n",
" questions=[\n",
" rg.RatingQuestion(\n",
" name=\"response-rating\",\n",
" description=\"How would you rate the quality of the response?\",\n",
" values=[1, 2, 3, 4, 5],\n",
" required=True,\n",
" ),\n",
" rg.TextQuestion(\n",
" name=\"response-feedback\",\n",
" description=\"What feedback do you have for the response?\",\n",
" required=False,\n",
" ),\n",
" ],\n",
" guidelines=\"You're asked to rate the quality of the response and provide feedback.\",\n",
")\n",
"\n",
"rg.init(\n",
" api_url=os.environ[\"ARGILLA_API_URL\"],\n",
" api_key=os.environ[\"ARGILLA_API_KEY\"],\n",
")\n",
"\n",
"dataset.push_to_argilla(\"langchain-dataset\")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"> 📌 NOTE: at the moment, just the prompt-response pairs are supported as `FeedbackDataset.fields`, so the `ArgillaCallbackHandler` will just track the prompt i.e. the LLM input, and the response i.e. the LLM output."
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Tracking"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"To use the `ArgillaCallbackHandler` you can either use the following code, or just reproduce one of the examples presented in the following sections."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain.callbacks import ArgillaCallbackHandler\n",
"\n",
"argilla_callback = ArgillaCallbackHandler(\n",
" dataset_name=\"langchain-dataset\",\n",
" api_url=os.environ[\"ARGILLA_API_URL\"],\n",
" api_key=os.environ[\"ARGILLA_API_KEY\"],\n",
")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Scenario 1: Tracking an LLM\n",
"\n",
"First, let's just run a single LLM a few times and capture the resulting prompt-response pairs in Argilla."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"LLMResult(generations=[[Generation(text='\\n\\nQ: What did the fish say when he hit the wall? \\nA: Dam.', generation_info={'finish_reason': 'stop', 'logprobs': None})], [Generation(text='\\n\\nThe Moon \\n\\nThe moon is high in the midnight sky,\\nSparkling like a star above.\\nThe night so peaceful, so serene,\\nFilling up the air with love.\\n\\nEver changing and renewing,\\nA never-ending light of grace.\\nThe moon remains a constant view,\\nA reminder of lifes gentle pace.\\n\\nThrough time and space it guides us on,\\nA never-fading beacon of hope.\\nThe moon shines down on us all,\\nAs it continues to rise and elope.', generation_info={'finish_reason': 'stop', 'logprobs': None})], [Generation(text='\\n\\nQ. What did one magnet say to the other magnet?\\nA. \"I find you very attractive!\"', generation_info={'finish_reason': 'stop', 'logprobs': None})], [Generation(text=\"\\n\\nThe world is charged with the grandeur of God.\\nIt will flame out, like shining from shook foil;\\nIt gathers to a greatness, like the ooze of oil\\nCrushed. Why do men then now not reck his rod?\\n\\nGenerations have trod, have trod, have trod;\\nAnd all is seared with trade; bleared, smeared with toil;\\nAnd wears man's smudge and shares man's smell: the soil\\nIs bare now, nor can foot feel, being shod.\\n\\nAnd for all this, nature is never spent;\\nThere lives the dearest freshness deep down things;\\nAnd though the last lights off the black West went\\nOh, morning, at the brown brink eastward, springs —\\n\\nBecause the Holy Ghost over the bent\\nWorld broods with warm breast and with ah! bright wings.\\n\\n~Gerard Manley Hopkins\", generation_info={'finish_reason': 'stop', 'logprobs': None})], [Generation(text='\\n\\nQ: What did one ocean say to the other ocean?\\nA: Nothing, they just waved.', generation_info={'finish_reason': 'stop', 'logprobs': None})], [Generation(text=\"\\n\\nA poem for you\\n\\nOn a field of green\\n\\nThe sky so blue\\n\\nA gentle breeze, the sun above\\n\\nA beautiful world, for us to love\\n\\nLife is a journey, full of surprise\\n\\nFull of joy and full of surprise\\n\\nBe brave and take small steps\\n\\nThe future will be revealed with depth\\n\\nIn the morning, when dawn arrives\\n\\nA fresh start, no reason to hide\\n\\nSomewhere down the road, there's a heart that beats\\n\\nBelieve in yourself, you'll always succeed.\", generation_info={'finish_reason': 'stop', 'logprobs': None})]], llm_output={'token_usage': {'completion_tokens': 504, 'total_tokens': 528, 'prompt_tokens': 24}, 'model_name': 'text-davinci-003'})"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from langchain.callbacks import ArgillaCallbackHandler, StdOutCallbackHandler\n",
"from langchain.llms import OpenAI\n",
"\n",
"argilla_callback = ArgillaCallbackHandler(\n",
" dataset_name=\"langchain-dataset\",\n",
" api_url=os.environ[\"ARGILLA_API_URL\"],\n",
" api_key=os.environ[\"ARGILLA_API_KEY\"],\n",
")\n",
"callbacks = [StdOutCallbackHandler(), argilla_callback]\n",
"\n",
"llm = OpenAI(temperature=0.9, callbacks=callbacks)\n",
"llm.generate([\"Tell me a joke\", \"Tell me a poem\"] * 3)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"![Argilla UI with LangChain LLM input-response](https://docs.argilla.io/en/latest/_images/llm.png)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Scenario 2: Tracking an LLM in a chain\n",
"\n",
"Then we can create a chain using a prompt template, and then track the initial prompt and the final response in Argilla."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new LLMChain chain...\u001b[0m\n",
"Prompt after formatting:\n",
"\u001b[32;1m\u001b[1;3mYou are a playwright. Given the title of play, it is your job to write a synopsis for that title.\n",
"Title: Documentary about Bigfoot in Paris\n",
"Playwright: This is a synopsis for the above play:\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"[{'text': \"\\n\\nDocumentary about Bigfoot in Paris focuses on the story of a documentary filmmaker and their search for evidence of the legendary Bigfoot creature in the city of Paris. The play follows the filmmaker as they explore the city, meeting people from all walks of life who have had encounters with the mysterious creature. Through their conversations, the filmmaker unravels the story of Bigfoot and finds out the truth about the creature's presence in Paris. As the story progresses, the filmmaker learns more and more about the mysterious creature, as well as the different perspectives of the people living in the city, and what they think of the creature. In the end, the filmmaker's findings lead them to some surprising and heartwarming conclusions about the creature's existence and the importance it holds in the lives of the people in Paris.\"}]"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from langchain.callbacks import ArgillaCallbackHandler, StdOutCallbackHandler\n",
"from langchain.llms import OpenAI\n",
"from langchain.chains import LLMChain\n",
"from langchain.prompts import PromptTemplate\n",
"\n",
"argilla_callback = ArgillaCallbackHandler(\n",
" dataset_name=\"langchain-dataset\",\n",
" api_url=os.environ[\"ARGILLA_API_URL\"],\n",
" api_key=os.environ[\"ARGILLA_API_KEY\"],\n",
")\n",
"callbacks = [StdOutCallbackHandler(), argilla_callback]\n",
"llm = OpenAI(temperature=0.9, callbacks=callbacks)\n",
"\n",
"template = \"\"\"You are a playwright. Given the title of play, it is your job to write a synopsis for that title.\n",
"Title: {title}\n",
"Playwright: This is a synopsis for the above play:\"\"\"\n",
"prompt_template = PromptTemplate(input_variables=[\"title\"], template=template)\n",
"synopsis_chain = LLMChain(llm=llm, prompt=prompt_template, callbacks=callbacks)\n",
"\n",
"test_prompts = [{\"title\": \"Documentary about Bigfoot in Paris\"}]\n",
"synopsis_chain.apply(test_prompts)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"![Argilla UI with LangChain Chain input-response](https://docs.argilla.io/en/latest/_images/chain.png)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Scenario 3: Using an Agent with Tools\n",
"\n",
"Finally, as a more advanced workflow, you can create an agent that uses some tools. So that `ArgillaCallbackHandler` will keep track of the input and the output, but not about the intermediate steps/thoughts, so that given a prompt we log the original prompt and the final response to that given prompt."
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"> Note that for this scenario we'll be using Google Search API (Serp API) so you will need to both install `google-search-results` as `pip install google-search-results`, and to set the Serp API Key as `os.environ[\"SERPAPI_API_KEY\"] = \"...\"` (you can find it at https://serpapi.com/dashboard), otherwise the example below won't work."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m I need to answer a historical question\n",
"Action: Search\n",
"Action Input: \"who was the first president of the United States of America\" \u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3mGeorge Washington\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m George Washington was the first president\n",
"Final Answer: George Washington was the first president of the United States of America.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"'George Washington was the first president of the United States of America.'"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from langchain.agents import AgentType, initialize_agent, load_tools\n",
"from langchain.callbacks import ArgillaCallbackHandler, StdOutCallbackHandler\n",
"from langchain.llms import OpenAI\n",
"\n",
"argilla_callback = ArgillaCallbackHandler(\n",
" dataset_name=\"langchain-dataset\",\n",
" api_url=os.environ[\"ARGILLA_API_URL\"],\n",
" api_key=os.environ[\"ARGILLA_API_KEY\"],\n",
")\n",
"callbacks = [StdOutCallbackHandler(), argilla_callback]\n",
"llm = OpenAI(temperature=0.9, callbacks=callbacks)\n",
"\n",
"tools = load_tools([\"serpapi\"], llm=llm, callbacks=callbacks)\n",
"agent = initialize_agent(\n",
" tools,\n",
" llm,\n",
" agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,\n",
" callbacks=callbacks,\n",
")\n",
"agent.run(\"Who was the first president of the United States of America?\")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"![Argilla UI with LangChain Agent input-response](https://docs.argilla.io/en/latest/_images/agent.png)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.3"
},
"vscode": {
"interpreter": {
"hash": "a53ebf4a859167383b364e7e7521d0add3c2dbbdecce4edf676e8c4634ff3fbb"
}
}
},
"nbformat": 4,
"nbformat_minor": 4
}

@ -1,6 +1,7 @@
"""Callback handlers that allow listening to events in LangChain."""
from langchain.callbacks.aim_callback import AimCallbackHandler
from langchain.callbacks.argilla_callback import ArgillaCallbackHandler
from langchain.callbacks.clearml_callback import ClearMLCallbackHandler
from langchain.callbacks.comet_ml_callback import CometCallbackHandler
from langchain.callbacks.human import HumanApprovalCallbackHandler
@ -17,6 +18,7 @@ from langchain.callbacks.wandb_callback import WandbCallbackHandler
from langchain.callbacks.whylabs_callback import WhyLabsCallbackHandler
__all__ = [
"ArgillaCallbackHandler",
"OpenAICallbackHandler",
"StdOutCallbackHandler",
"AimCallbackHandler",

@ -0,0 +1,316 @@
import os
import warnings
from typing import Any, Dict, List, Optional, Union
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult
class ArgillaCallbackHandler(BaseCallbackHandler):
"""Callback Handler that logs into Argilla.
Args:
dataset_name: name of the `FeedbackDataset` in Argilla. Note that it must
exist in advance. If you need help on how to create a `FeedbackDataset` in
Argilla, please visit
https://docs.argilla.io/en/latest/guides/llms/practical_guides/use_argilla_callback_in_langchain.html.
workspace_name: name of the workspace in Argilla where the specified
`FeedbackDataset` lives in. Defaults to `None`, which means that the
default workspace will be used.
api_url: URL of the Argilla Server that we want to use, and where the
`FeedbackDataset` lives in. Defaults to `None`, which means that either
`ARGILLA_API_URL` environment variable or the default http://localhost:6900
will be used.
api_key: API Key to connect to the Argilla Server. Defaults to `None`, which
means that either `ARGILLA_API_KEY` environment variable or the default
`argilla.apikey` will be used.
Raises:
ImportError: if the `argilla` package is not installed.
ConnectionError: if the connection to Argilla fails.
FileNotFoundError: if the `FeedbackDataset` retrieval from Argilla fails.
Examples:
>>> from langchain.llms import OpenAI
>>> from langchain.callbacks import ArgillaCallbackHandler
>>> argilla_callback = ArgillaCallbackHandler(
... dataset_name="my-dataset",
... workspace_name="my-workspace",
... api_url="http://localhost:6900",
... api_key="argilla.apikey",
... )
>>> llm = OpenAI(
... temperature=0,
... callbacks=[argilla_callback],
... verbose=True,
... openai_api_key="API_KEY_HERE",
... )
>>> llm.generate([
... "What is the best NLP-annotation tool out there? (no bias at all)",
... ])
"Argilla, no doubt about it."
"""
def __init__(
self,
dataset_name: str,
workspace_name: Optional[str] = None,
api_url: Optional[str] = None,
api_key: Optional[str] = None,
) -> None:
"""Initializes the `ArgillaCallbackHandler`.
Args:
dataset_name: name of the `FeedbackDataset` in Argilla. Note that it must
exist in advance. If you need help on how to create a `FeedbackDataset`
in Argilla, please visit
https://docs.argilla.io/en/latest/guides/llms/practical_guides/use_argilla_callback_in_langchain.html.
workspace_name: name of the workspace in Argilla where the specified
`FeedbackDataset` lives in. Defaults to `None`, which means that the
default workspace will be used.
api_url: URL of the Argilla Server that we want to use, and where the
`FeedbackDataset` lives in. Defaults to `None`, which means that either
`ARGILLA_API_URL` environment variable or the default
http://localhost:6900 will be used.
api_key: API Key to connect to the Argilla Server. Defaults to `None`, which
means that either `ARGILLA_API_KEY` environment variable or the default
`argilla.apikey` will be used.
Raises:
ImportError: if the `argilla` package is not installed.
ConnectionError: if the connection to Argilla fails.
FileNotFoundError: if the `FeedbackDataset` retrieval from Argilla fails.
"""
super().__init__()
# Import Argilla (not via `import_argilla` to keep hints in IDEs)
try:
import argilla as rg # noqa: F401
except ImportError:
raise ImportError(
"To use the Argilla callback manager you need to have the `argilla` "
"Python package installed. Please install it with `pip install argilla`"
)
# Show a warning message if Argilla will assume the default values will be used
if api_url is None and os.getenv("ARGILLA_API_URL") is None:
warnings.warn(
(
"Since `api_url` is None, and the env var `ARGILLA_API_URL` is not"
" set, it will default to `http://localhost:6900`."
),
)
if api_key is None and os.getenv("ARGILLA_API_KEY") is None:
warnings.warn(
(
"Since `api_key` is None, and the env var `ARGILLA_API_KEY` is not"
" set, it will default to `argilla.apikey`."
),
)
# Connect to Argilla with the provided credentials, if applicable
try:
rg.init(
api_key=api_key,
api_url=api_url,
)
except Exception as e:
raise ConnectionError(
f"Could not connect to Argilla with exception: '{e}'.\n"
"Please check your `api_key` and `api_url`, and make sure that "
"the Argilla server is up and running. If the problem persists "
"please report it to https://github.com/argilla-io/argilla/issues "
"with the label `langchain`."
) from e
# Set the Argilla variables
self.dataset_name = dataset_name
self.workspace_name = workspace_name or rg.get_workspace()
# Retrieve the `FeedbackDataset` from Argilla (without existing records)
try:
self.dataset = rg.FeedbackDataset.from_argilla(
name=self.dataset_name,
workspace=self.workspace_name,
with_records=False,
)
except Exception as e:
raise FileNotFoundError(
"`FeedbackDataset` retrieval from Argilla failed with exception:"
f" '{e}'.\nPlease check that the dataset with"
f" name={self.dataset_name} in the"
f" workspace={self.workspace_name} exists in advance. If you need help"
" on how to create a `langchain`-compatible `FeedbackDataset` in"
" Argilla, please visit"
" https://docs.argilla.io/en/latest/guides/llms/practical_guides/use_argilla_callback_in_langchain.html." # noqa: E501
" If the problem persists please report it to"
" https://github.com/argilla-io/argilla/issues with the label"
" `langchain`."
) from e
supported_fields = ["prompt", "response"]
if supported_fields != [field.name for field in self.dataset.fields]:
raise ValueError(
f"`FeedbackDataset` with name={self.dataset_name} in the"
f" workspace={self.workspace_name} "
"had fields that are not supported yet for the `langchain` integration."
" Supported fields are: "
f"{supported_fields}, and the current `FeedbackDataset` fields are"
f" {[field.name for field in self.dataset.fields]}. "
"For more information on how to create a `langchain`-compatible"
" `FeedbackDataset` in Argilla, please visit"
" https://docs.argilla.io/en/latest/guides/llms/practical_guides/use_argilla_callback_in_langchain.html." # noqa: E501
)
self.prompts: Dict[str, List[str]] = {}
warnings.warn(
(
"The `ArgillaCallbackHandler` is currently in beta and is subject to "
"change based on updates to `langchain`. Please report any issues to "
"https://github.com/argilla-io/argilla/issues with the tag `langchain`."
),
)
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Save the prompts in memory when an LLM starts."""
self.prompts.update({str(kwargs["parent_run_id"] or kwargs["run_id"]): prompts})
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Do nothing when a new token is generated."""
pass
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Log records to Argilla when an LLM ends."""
# Do nothing if there's a parent_run_id, since we will log the records when
# the chain ends
if kwargs["parent_run_id"]:
return
# Creates the records and adds them to the `FeedbackDataset`
prompts = self.prompts[str(kwargs["run_id"])]
for prompt, generations in zip(prompts, response.generations):
self.dataset.add_records(
records=[
{
"fields": {
"prompt": prompt,
"response": generation.text.strip(),
},
}
for generation in generations
]
)
# Push the records to Argilla
self.dataset.push_to_argilla()
# Pop current run from `self.runs`
self.prompts.pop(str(kwargs["run_id"]))
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Do nothing when LLM outputs an error."""
pass
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Do nothing when LLM chain starts."""
if "input" in inputs:
self.prompts.update(
{
str(kwargs["parent_run_id"] or kwargs["run_id"]): (
inputs["input"]
if isinstance(inputs["input"], list)
else [inputs["input"]]
)
}
)
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Do nothing when LLM chain ends."""
prompts = self.prompts[str(kwargs["parent_run_id"] or kwargs["run_id"])]
if "outputs" in outputs:
# Creates the records and adds them to the `FeedbackDataset`
self.dataset.add_records(
records=[
{
"fields": {
"prompt": prompt,
"response": output["text"].strip(),
},
}
for prompt, output in zip(prompts, outputs["outputs"])
]
)
elif "output" in outputs:
# Creates the records and adds them to the `FeedbackDataset`
self.dataset.add_records(
records=[
{
"fields": {
"prompt": " ".join(prompts),
"response": outputs["output"].strip(),
},
}
]
)
else:
raise ValueError(
"The `outputs` dictionary did not contain the expected keys `outputs` "
"or `output`."
)
# Push the records to Argilla
self.dataset.push_to_argilla()
# Pop current run from `self.runs`
self.prompts.pop(str(kwargs["parent_run_id"] or kwargs["run_id"]))
def on_chain_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Do nothing when LLM chain outputs an error."""
pass
def on_tool_start(
self,
serialized: Dict[str, Any],
input_str: str,
**kwargs: Any,
) -> None:
"""Do nothing when tool starts."""
pass
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Do nothing when agent takes a specific action."""
pass
def on_tool_end(
self,
output: str,
observation_prefix: Optional[str] = None,
llm_prefix: Optional[str] = None,
**kwargs: Any,
) -> None:
"""Do nothing when tool ends."""
pass
def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Do nothing when tool outputs an error."""
pass
def on_text(self, text: str, **kwargs: Any) -> None:
"""Do nothing"""
pass
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Do nothing"""
pass
Loading…
Cancel
Save