mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
Add LabelStudio integration (#8880)
This PR introduces [Label Studio](https://labelstud.io/) integration with LangChain via `LabelStudioCallbackHandler`: - sending data to the Label Studio instance - labeling dataset for supervised LLM finetuning - rating model responses - tracking and displaying chat history - support for custom data labeling workflow ### Example ``` chat_llm = ChatOpenAI(callbacks=[LabelStudioCallbackHandler(mode="chat")]) chat_llm([ SystemMessage(content="Always use emojis in your responses."), HumanMessage(content="Hey AI, how's your day going?"), AIMessage(content="🤖 I don't have feelings, but I'm running smoothly! How can I help you today?"), HumanMessage(content="I'm feeling a bit down. Any advice?"), AIMessage(content="🤗 I'm sorry to hear that. Remember, it's okay to seek help or talk to someone if you need to. 💬"), HumanMessage(content="Can you tell me a joke to lighten the mood?"), AIMessage(content="Of course! 🎭 Why did the scarecrow win an award? Because he was outstanding in his field! 🌾"), HumanMessage(content="Haha, that was a good one! Thanks for cheering me up."), AIMessage(content="Always here to help! 😊 If you need anything else, just let me know."), HumanMessage(content="Will do! By the way, can you recommend a good movie?"), ]) ``` <img width="906" alt="image" src="https://github.com/langchain-ai/langchain/assets/6087484/0a1cf559-0bd3-4250-ad96-6e71dbb1d2f3"> ### Dependencies - [label-studio](https://pypi.org/project/label-studio/) - [label-studio-sdk](https://pypi.org/project/label-studio-sdk/) https://twitter.com/labelstudiohq --------- Co-authored-by: nik <nik@heartex.net>
This commit is contained in:
parent
8cb2594562
commit
16af5f8690
382
docs/extras/integrations/callbacks/labelstudio.ipynb
Normal file
382
docs/extras/integrations/callbacks/labelstudio.ipynb
Normal file
@ -0,0 +1,382 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"collapsed": true,
|
||||
"pycharm": {
|
||||
"name": "#%% md\n"
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"# Label Studio\n",
|
||||
"\n",
|
||||
"<div>\n",
|
||||
"<img src=\"https://labelstudio-pub.s3.amazonaws.com/lc/open-source-data-labeling-platform.png\" width=\"400\"/>\n",
|
||||
"</div>\n",
|
||||
"\n",
|
||||
"Label Studio is an open-source data labeling platform that provides LangChain with flexibility when it comes to labeling data for fine-tuning large language models (LLMs). It also enables the preparation of custom training data and the collection and evaluation of responses through human feedback.\n",
|
||||
"\n",
|
||||
"In this guide, you will learn how to connect a LangChain pipeline to Label Studio to:\n",
|
||||
"\n",
|
||||
"- Aggregate all input prompts, conversations, and responses in a single LabelStudio project. This consolidates all the data in one place for easier labeling and analysis.\n",
|
||||
"- Refine prompts and responses to create a dataset for supervised fine-tuning (SFT) and reinforcement learning with human feedback (RLHF) scenarios. The labeled data can be used to further train the LLM to improve its performance.\n",
|
||||
"- Evaluate model responses through human feedback. LabelStudio provides an interface for humans to review and provide feedback on model responses, allowing evaluation and iteration."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"name": "#%% md\n"
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"## Installation and setup"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"name": "#%% md\n"
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"First install latest versions of Label Studio and Label Studio API client:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip install -U label-studio label-studio-sdk openai"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"name": "#%% md\n"
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"Next, run `label-studio` on the command line to start the local LabelStudio instance at `http://localhost:8080`. See the [Label Studio installation guide](https://labelstud.io/guide/install) for more options."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"name": "#%% md\n"
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"You'll need a token to make API calls.\n",
|
||||
"\n",
|
||||
"Open your LabelStudio instance in your browser, go to `Account & Settings > Access Token` and copy the key.\n",
|
||||
"\n",
|
||||
"Set environment variables with your LabelStudio URL, API key and OpenAI API key:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"\n",
|
||||
"os.environ['LABEL_STUDIO_URL'] = '<YOUR-LABEL-STUDIO-URL>' # e.g. http://localhost:8080\n",
|
||||
"os.environ['LABEL_STUDIO_API_KEY'] = '<YOUR-LABEL-STUDIO-API-KEY>'\n",
|
||||
"os.environ['OPENAI_API_KEY'] = '<YOUR-OPENAI-API-KEY>'"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"name": "#%% md\n"
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"## Collecting LLMs prompts and responses"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The data used for labeling is stored in projects within Label Studio. Every project is identified by an XML configuration that details the specifications for input and output data. \n",
|
||||
"\n",
|
||||
"Create a project that takes human input in text format and outputs an editable LLM response in a text area:\n",
|
||||
"\n",
|
||||
"```xml\n",
|
||||
"<View>\n",
|
||||
"<Style>\n",
|
||||
" .prompt-box {\n",
|
||||
" background-color: white;\n",
|
||||
" border-radius: 10px;\n",
|
||||
" box-shadow: 0px 4px 6px rgba(0, 0, 0, 0.1);\n",
|
||||
" padding: 20px;\n",
|
||||
" }\n",
|
||||
"</Style>\n",
|
||||
"<View className=\"root\">\n",
|
||||
" <View className=\"prompt-box\">\n",
|
||||
" <Text name=\"prompt\" value=\"$prompt\"/>\n",
|
||||
" </View>\n",
|
||||
" <TextArea name=\"response\" toName=\"prompt\"\n",
|
||||
" maxSubmissions=\"1\" editable=\"true\"\n",
|
||||
" required=\"true\"/>\n",
|
||||
"</View>\n",
|
||||
"<Header value=\"Rate the response:\"/>\n",
|
||||
"<Rating name=\"rating\" toName=\"prompt\"/>\n",
|
||||
"</View>\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"1. To create a project in Label Studio, click on the \"Create\" button. \n",
|
||||
"2. Enter a name for your project in the \"Project Name\" field, such as `My Project`.\n",
|
||||
"3. Navigate to `Labeling Setup > Custom Template` and paste the XML configuration provided above."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"name": "#%% md\n"
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"You can collect input LLM prompts and output responses in a LabelStudio project, connecting it via `LabelStudioCallbackHandler`:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.llms import OpenAI\n",
|
||||
"from langchain.callbacks import LabelStudioCallbackHandler\n",
|
||||
"\n",
|
||||
"llm = OpenAI(\n",
|
||||
" temperature=0,\n",
|
||||
" callbacks=[\n",
|
||||
" LabelStudioCallbackHandler(\n",
|
||||
" project_name=\"My Project\"\n",
|
||||
" )]\n",
|
||||
")\n",
|
||||
"print(llm(\"Tell me a joke\"))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"name": "#%% md\n"
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"In the Label Studio, open `My Project`. You will see the prompts, responses, and metadata like the model name. "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"name": "#%% md\n"
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"## Collecting Chat model Dialogues"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"You can also track and display full chat dialogues in LabelStudio, with the ability to rate and modify the last response:\n",
|
||||
"\n",
|
||||
"1. Open Label Studio and click on the \"Create\" button.\n",
|
||||
"2. Enter a name for your project in the \"Project Name\" field, such as `New Project with Chat`.\n",
|
||||
"3. Navigate to Labeling Setup > Custom Template and paste the following XML configuration:\n",
|
||||
"\n",
|
||||
"```xml\n",
|
||||
"<View>\n",
|
||||
"<View className=\"root\">\n",
|
||||
" <Paragraphs name=\"dialogue\"\n",
|
||||
" value=\"$prompt\"\n",
|
||||
" layout=\"dialogue\"\n",
|
||||
" textKey=\"content\"\n",
|
||||
" nameKey=\"role\"\n",
|
||||
" granularity=\"sentence\"/>\n",
|
||||
" <Header value=\"Final response:\"/>\n",
|
||||
" <TextArea name=\"response\" toName=\"dialogue\"\n",
|
||||
" maxSubmissions=\"1\" editable=\"true\"\n",
|
||||
" required=\"true\"/>\n",
|
||||
"</View>\n",
|
||||
"<Header value=\"Rate the response:\"/>\n",
|
||||
"<Rating name=\"rating\" toName=\"dialogue\"/>\n",
|
||||
"</View>\n",
|
||||
"```"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.chat_models import ChatOpenAI\n",
|
||||
"from langchain.schema import HumanMessage, SystemMessage\n",
|
||||
"from langchain.callbacks import LabelStudioCallbackHandler\n",
|
||||
"\n",
|
||||
"chat_llm = ChatOpenAI(callbacks=[\n",
|
||||
" LabelStudioCallbackHandler(\n",
|
||||
" mode=\"chat\",\n",
|
||||
" project_name=\"New Project with Chat\",\n",
|
||||
" )\n",
|
||||
"])\n",
|
||||
"llm_results = chat_llm([\n",
|
||||
" SystemMessage(content=\"Always use a lot of emojis\"),\n",
|
||||
" HumanMessage(content=\"Tell me a joke\")\n",
|
||||
"])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"In Label Studio, open \"New Project with Chat\". Click on a created task to view dialog history and edit/annotate responses."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"name": "#%% md\n"
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"## Custom Labeling Configuration"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"name": "#%% md\n"
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"You can modify the default labeling configuration in LabelStudio to add more target labels like response sentiment, relevance, and many [other types annotator's feedback](https://labelstud.io/tags/).\n",
|
||||
"\n",
|
||||
"New labeling configuration can be added from UI: go to `Settings > Labeling Interface` and set up a custom configuration with additional tags like `Choices` for sentiment or `Rating` for relevance. Keep in mind that [`TextArea` tag](https://labelstud.io/tags/textarea) should be presented in any configuration to display the LLM responses.\n",
|
||||
"\n",
|
||||
"Alternatively, you can specify the labeling configuration on the initial call before project creation:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"ls = LabelStudioCallbackHandler(project_config='''\n",
|
||||
"<View>\n",
|
||||
"<Text name=\"prompt\" value=\"$prompt\"/>\n",
|
||||
"<TextArea name=\"response\" toName=\"prompt\"/>\n",
|
||||
"<TextArea name=\"user_feedback\" toName=\"prompt\"/>\n",
|
||||
"<Rating name=\"rating\" toName=\"prompt\"/>\n",
|
||||
"<Choices name=\"sentiment\" toName=\"prompt\">\n",
|
||||
" <Choice value=\"Positive\"/>\n",
|
||||
" <Choice value=\"Negative\"/>\n",
|
||||
"</Choices>\n",
|
||||
"</View>\n",
|
||||
"''')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Note that if the project doesn't exist, it will be created with the specified labeling configuration."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"name": "#%% md\n"
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"## Other parameters"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"pycharm": {
|
||||
"name": "#%% md\n"
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"The `LabelStudioCallbackHandler` accepts several optional parameters:\n",
|
||||
"\n",
|
||||
"- **api_key** - Label Studio API key. Overrides environmental variable `LABEL_STUDIO_API_KEY`.\n",
|
||||
"- **url** - Label Studio URL. Overrides `LABEL_STUDIO_URL`, default `http://localhost:8080`.\n",
|
||||
"- **project_id** - Existing Label Studio project ID. Overrides `LABEL_STUDIO_PROJECT_ID`. Stores data in this project.\n",
|
||||
"- **project_name** - Project name if project ID not specified. Creates a new project. Default is `\"LangChain-%Y-%m-%d\"` formatted with the current date.\n",
|
||||
"- **project_config** - [custom labeling configuration](#custom-labeling-configuration)\n",
|
||||
"- **mode**: use this shortcut to create target configuration from scratch:\n",
|
||||
" - `\"prompt\"` - Single prompt, single response. Default.\n",
|
||||
" - `\"chat\"` - Multi-turn chat mode.\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "labelops",
|
||||
"language": "python",
|
||||
"name": "labelops"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.16"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 1
|
||||
}
|
@ -18,6 +18,7 @@ from langchain.callbacks.file import FileCallbackHandler
|
||||
from langchain.callbacks.flyte_callback import FlyteCallbackHandler
|
||||
from langchain.callbacks.human import HumanApprovalCallbackHandler
|
||||
from langchain.callbacks.infino_callback import InfinoCallbackHandler
|
||||
from langchain.callbacks.labelstudio_callback import LabelStudioCallbackHandler
|
||||
from langchain.callbacks.manager import (
|
||||
get_openai_callback,
|
||||
tracing_enabled,
|
||||
@ -68,4 +69,5 @@ __all__ = [
|
||||
"wandb_tracing_enabled",
|
||||
"FlyteCallbackHandler",
|
||||
"SageMakerCallbackHandler",
|
||||
"LabelStudioCallbackHandler",
|
||||
]
|
||||
|
392
libs/langchain/langchain/callbacks/labelstudio_callback.py
Normal file
392
libs/langchain/langchain/callbacks/labelstudio_callback.py
Normal file
@ -0,0 +1,392 @@
|
||||
import os
|
||||
import warnings
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from uuid import UUID
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.schema import (
|
||||
AgentAction,
|
||||
AgentFinish,
|
||||
BaseMessage,
|
||||
ChatMessage,
|
||||
Generation,
|
||||
LLMResult,
|
||||
)
|
||||
|
||||
|
||||
class LabelStudioMode(Enum):
|
||||
PROMPT = "prompt"
|
||||
CHAT = "chat"
|
||||
|
||||
|
||||
def get_default_label_configs(
|
||||
mode: Union[str, LabelStudioMode]
|
||||
) -> Tuple[str, LabelStudioMode]:
|
||||
_default_label_configs = {
|
||||
LabelStudioMode.PROMPT.value: """
|
||||
<View>
|
||||
<Style>
|
||||
.prompt-box {
|
||||
background-color: white;
|
||||
border-radius: 10px;
|
||||
box-shadow: 0px 4px 6px rgba(0, 0, 0, 0.1);
|
||||
padding: 20px;
|
||||
}
|
||||
</Style>
|
||||
<View className="root">
|
||||
<View className="prompt-box">
|
||||
<Text name="prompt" value="$prompt"/>
|
||||
</View>
|
||||
<TextArea name="response" toName="prompt"
|
||||
maxSubmissions="1" editable="true"
|
||||
required="true"/>
|
||||
</View>
|
||||
<Header value="Rate the response:"/>
|
||||
<Rating name="rating" toName="prompt"/>
|
||||
</View>""",
|
||||
LabelStudioMode.CHAT.value: """
|
||||
<View>
|
||||
<View className="root">
|
||||
<Paragraphs name="dialogue"
|
||||
value="$prompt"
|
||||
layout="dialogue"
|
||||
textKey="content"
|
||||
nameKey="role"
|
||||
granularity="sentence"/>
|
||||
<Header value="Final response:"/>
|
||||
<TextArea name="response" toName="dialogue"
|
||||
maxSubmissions="1" editable="true"
|
||||
required="true"/>
|
||||
</View>
|
||||
<Header value="Rate the response:"/>
|
||||
<Rating name="rating" toName="dialogue"/>
|
||||
</View>""",
|
||||
}
|
||||
|
||||
if isinstance(mode, str):
|
||||
mode = LabelStudioMode(mode)
|
||||
|
||||
return _default_label_configs[mode.value], mode
|
||||
|
||||
|
||||
class LabelStudioCallbackHandler(BaseCallbackHandler):
|
||||
"""Label Studio callback handler.
|
||||
Provides the ability to send predictions to Label Studio
|
||||
for human evaluation, feedback and annotation.
|
||||
|
||||
Parameters:
|
||||
api_key: Label Studio API key
|
||||
url: Label Studio URL
|
||||
project_id: Label Studio project ID
|
||||
project_name: Label Studio project name
|
||||
project_config: Label Studio project config (XML)
|
||||
mode: Label Studio mode ("prompt" or "chat")
|
||||
|
||||
Examples:
|
||||
>>> from langchain.llms import OpenAI
|
||||
>>> from langchain.callbacks import LabelStudioCallbackHandler
|
||||
>>> handler = LabelStudioCallbackHandler(
|
||||
... api_key='<your_key_here>',
|
||||
... url='http://localhost:8080',
|
||||
... project_name='LangChain-%Y-%m-%d',
|
||||
... mode='prompt'
|
||||
... )
|
||||
>>> llm = OpenAI(callbacks=[handler])
|
||||
>>> llm.predict('Tell me a story about a dog.')
|
||||
"""
|
||||
|
||||
DEFAULT_PROJECT_NAME = "LangChain-%Y-%m-%d"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: Optional[str] = None,
|
||||
url: Optional[str] = None,
|
||||
project_id: Optional[int] = None,
|
||||
project_name: str = DEFAULT_PROJECT_NAME,
|
||||
project_config: Optional[str] = None,
|
||||
mode: Union[str, LabelStudioMode] = LabelStudioMode.PROMPT,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Import LabelStudio SDK
|
||||
try:
|
||||
import label_studio_sdk as ls
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
f"You're using {self.__class__.__name__} in your code,"
|
||||
f" but you don't have the LabelStudio SDK "
|
||||
f"Python package installed or upgraded to the latest version. "
|
||||
f"Please run `pip install -U label-studio-sdk`"
|
||||
f" before using this callback."
|
||||
)
|
||||
|
||||
# Check if Label Studio API key is provided
|
||||
if not api_key:
|
||||
if os.getenv("LABEL_STUDIO_API_KEY"):
|
||||
api_key = str(os.getenv("LABEL_STUDIO_API_KEY"))
|
||||
else:
|
||||
raise ValueError(
|
||||
f"You're using {self.__class__.__name__} in your code,"
|
||||
f" Label Studio API key is not provided. "
|
||||
f"Please provide Label Studio API key: "
|
||||
f"go to the Label Studio instance, navigate to "
|
||||
f"Account & Settings -> Access Token and copy the key. "
|
||||
f"Use the key as a parameter for the callback: "
|
||||
f"{self.__class__.__name__}"
|
||||
f"(label_studio_api_key='<your_key_here>', ...) or "
|
||||
f"set the environment variable LABEL_STUDIO_API_KEY=<your_key_here>"
|
||||
)
|
||||
self.api_key = api_key
|
||||
|
||||
if not url:
|
||||
if os.getenv("LABEL_STUDIO_URL"):
|
||||
url = os.getenv("LABEL_STUDIO_URL")
|
||||
else:
|
||||
warnings.warn(
|
||||
f"Label Studio URL is not provided, "
|
||||
f"using default URL: {ls.LABEL_STUDIO_DEFAULT_URL}"
|
||||
f"If you want to provide your own URL, use the parameter: "
|
||||
f"{self.__class__.__name__}"
|
||||
f"(label_studio_url='<your_url_here>', ...) "
|
||||
f"or set the environment variable LABEL_STUDIO_URL=<your_url_here>"
|
||||
)
|
||||
url = ls.LABEL_STUDIO_DEFAULT_URL
|
||||
self.url = url
|
||||
|
||||
# Maps run_id to prompts
|
||||
self.payload: Dict[str, Dict] = {}
|
||||
|
||||
self.ls_client = ls.Client(url=self.url, api_key=self.api_key)
|
||||
self.project_name = project_name
|
||||
if project_config:
|
||||
self.project_config = project_config
|
||||
self.mode = None
|
||||
else:
|
||||
self.project_config, self.mode = get_default_label_configs(mode)
|
||||
|
||||
self.project_id = project_id or os.getenv("LABEL_STUDIO_PROJECT_ID")
|
||||
if self.project_id is not None:
|
||||
self.ls_project = self.ls_client.get_project(int(self.project_id))
|
||||
else:
|
||||
project_title = datetime.today().strftime(self.project_name)
|
||||
existing_projects = self.ls_client.get_projects(title=project_title)
|
||||
if existing_projects:
|
||||
self.ls_project = existing_projects[0]
|
||||
self.project_id = self.ls_project.id
|
||||
else:
|
||||
self.ls_project = self.ls_client.create_project(
|
||||
title=project_title, label_config=self.project_config
|
||||
)
|
||||
self.project_id = self.ls_project.id
|
||||
self.parsed_label_config = self.ls_project.parsed_label_config
|
||||
|
||||
# Find the first TextArea tag
|
||||
# "from_name", "to_name", "value" will be used to create predictions
|
||||
self.from_name, self.to_name, self.value, self.input_type = (
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
for tag_name, tag_info in self.parsed_label_config.items():
|
||||
if tag_info["type"] == "TextArea":
|
||||
self.from_name = tag_name
|
||||
self.to_name = tag_info["to_name"][0]
|
||||
self.value = tag_info["inputs"][0]["value"]
|
||||
self.input_type = tag_info["inputs"][0]["type"]
|
||||
break
|
||||
if not self.from_name:
|
||||
error_message = (
|
||||
f'Label Studio project "{self.project_name}" '
|
||||
f"does not have a TextArea tag. "
|
||||
f"Please add a TextArea tag to the project."
|
||||
)
|
||||
if self.mode == LabelStudioMode.PROMPT:
|
||||
error_message += (
|
||||
"\nHINT: go to project Settings -> "
|
||||
"Labeling Interface -> Browse Templates"
|
||||
' and select "Generative AI -> '
|
||||
'Supervised Language Model Fine-tuning" template.'
|
||||
)
|
||||
else:
|
||||
error_message += (
|
||||
"\nHINT: go to project Settings -> "
|
||||
"Labeling Interface -> Browse Templates"
|
||||
" and check available templates under "
|
||||
'"Generative AI" section.'
|
||||
)
|
||||
raise ValueError(error_message)
|
||||
|
||||
def add_prompts_generations(
|
||||
self, run_id: str, generations: List[List[Generation]]
|
||||
) -> None:
|
||||
# Create tasks in Label Studio
|
||||
tasks = []
|
||||
prompts = self.payload[run_id]["prompts"]
|
||||
model_version = (
|
||||
self.payload[run_id]["kwargs"]
|
||||
.get("invocation_params", {})
|
||||
.get("model_name")
|
||||
)
|
||||
for prompt, generation in zip(prompts, generations):
|
||||
tasks.append(
|
||||
{
|
||||
"data": {
|
||||
self.value: prompt,
|
||||
"run_id": run_id,
|
||||
},
|
||||
"predictions": [
|
||||
{
|
||||
"result": [
|
||||
{
|
||||
"from_name": self.from_name,
|
||||
"to_name": self.to_name,
|
||||
"type": "textarea",
|
||||
"value": {"text": [g.text for g in generation]},
|
||||
}
|
||||
],
|
||||
"model_version": model_version,
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
self.ls_project.import_tasks(tasks)
|
||||
|
||||
def on_llm_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
prompts: List[str],
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Save the prompts in memory when an LLM starts."""
|
||||
if self.input_type != "Text":
|
||||
raise ValueError(
|
||||
f'\nLabel Studio project "{self.project_name}" '
|
||||
f"has an input type <{self.input_type}>. "
|
||||
f'To make it work with the mode="chat", '
|
||||
f"the input type should be <Text>.\n"
|
||||
f"Read more here https://labelstud.io/tags/text"
|
||||
)
|
||||
run_id = str(kwargs["run_id"])
|
||||
self.payload[run_id] = {"prompts": prompts, "kwargs": kwargs}
|
||||
|
||||
def _get_message_role(self, message: BaseMessage) -> str:
|
||||
"""Get the role of the message."""
|
||||
if isinstance(message, ChatMessage):
|
||||
return message.role
|
||||
else:
|
||||
return message.__class__.__name__
|
||||
|
||||
def on_chat_model_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
messages: List[List[BaseMessage]],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Save the prompts in memory when an LLM starts."""
|
||||
if self.input_type != "Paragraphs":
|
||||
raise ValueError(
|
||||
f'\nLabel Studio project "{self.project_name}" '
|
||||
f"has an input type <{self.input_type}>. "
|
||||
f'To make it work with the mode="chat", '
|
||||
f"the input type should be <Paragraphs>.\n"
|
||||
f"Read more here https://labelstud.io/tags/paragraphs"
|
||||
)
|
||||
|
||||
prompts = []
|
||||
for message_list in messages:
|
||||
dialog = []
|
||||
for message in message_list:
|
||||
dialog.append(
|
||||
{
|
||||
"role": self._get_message_role(message),
|
||||
"content": message.content,
|
||||
}
|
||||
)
|
||||
prompts.append(dialog)
|
||||
self.payload[str(run_id)] = {
|
||||
"prompts": prompts,
|
||||
"tags": tags,
|
||||
"metadata": metadata,
|
||||
"run_id": run_id,
|
||||
"parent_run_id": parent_run_id,
|
||||
"kwargs": kwargs,
|
||||
}
|
||||
|
||||
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
"""Do nothing when a new token is generated."""
|
||||
pass
|
||||
|
||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
"""Create a new Label Studio task for each prompt and generation."""
|
||||
run_id = str(kwargs["run_id"])
|
||||
|
||||
# Submit results to Label Studio
|
||||
self.add_prompts_generations(run_id, response.generations)
|
||||
|
||||
# Pop current run from `self.runs`
|
||||
self.payload.pop(run_id)
|
||||
|
||||
def on_llm_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Do nothing when LLM outputs an error."""
|
||||
pass
|
||||
|
||||
def on_chain_start(
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
||||
pass
|
||||
|
||||
def on_chain_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Do nothing when LLM chain outputs an error."""
|
||||
pass
|
||||
|
||||
def on_tool_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
input_str: str,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Do nothing when tool starts."""
|
||||
pass
|
||||
|
||||
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
|
||||
"""Do nothing when agent takes a specific action."""
|
||||
pass
|
||||
|
||||
def on_tool_end(
|
||||
self,
|
||||
output: str,
|
||||
observation_prefix: Optional[str] = None,
|
||||
llm_prefix: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Do nothing when tool ends."""
|
||||
pass
|
||||
|
||||
def on_tool_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Do nothing when tool outputs an error."""
|
||||
pass
|
||||
|
||||
def on_text(self, text: str, **kwargs: Any) -> None:
|
||||
"""Do nothing"""
|
||||
pass
|
||||
|
||||
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
|
||||
"""Do nothing"""
|
||||
pass
|
Loading…
Reference in New Issue
Block a user