Add `TrainableLLM` (#11721)

- **Description:** Add `TrainableLLM` for LLMs that support fine-tuning
  - **Tag maintainer:** @hwchase17

This PR adds training methods to `GradientLLM`.
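
For context, a minimal sketch of how the new training API might be called; the credential and model values are placeholders, and the constructor field names are taken from the diff below rather than verified against a live workspace:

```python
from langchain.llms import GradientLLM

# Placeholder credentials/IDs for illustration only.
llm = GradientLLM(
    model_id="YOUR_ADAPTER_MODEL_ID",
    gradient_access_token="YOUR_GRADIENT_ACCESS_TOKEN",
    gradient_workspace_id="YOUR_GRADIENT_WORKSPACE_ID",
)

# Fine-tune the model on raw text samples; returns a TrainResult
# dict carrying the average loss per trainable token.
result = llm.train_unsupervised(["Sample text to train on."])
print(result["loss"])
```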

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>

@@ -50,12 +50,7 @@ from langchain.load.dump import dumpd
 from langchain.prompts.base import StringPromptValue
 from langchain.prompts.chat import ChatPromptValue
 from langchain.pydantic_v1 import Field, root_validator, validator
-from langchain.schema import (
-    Generation,
-    LLMResult,
-    PromptValue,
-    RunInfo,
-)
+from langchain.schema import Generation, LLMResult, PromptValue, RunInfo
 from langchain.schema.language_model import BaseLanguageModel, LanguageModelInput
 from langchain.schema.messages import AIMessage, BaseMessage, get_buffer_string
 from langchain.schema.output import GenerationChunk

@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Mapping, Optional, Union
+from typing import Any, Dict, List, Mapping, Optional, Sequence, TypedDict, Union
 
 import aiohttp
 import requests
@@ -13,6 +13,10 @@ from langchain.pydantic_v1 import Extra, root_validator
 from langchain.utils import get_from_dict_or_env
 
 
+class TrainResult(TypedDict):
+    loss: float
+
+
 class GradientLLM(LLM):
     """Gradient.ai LLM Endpoints.
@@ -125,6 +129,51 @@ class GradientLLM(LLM):
         """Return type of llm."""
         return "gradient"
 
+    def _kwargs_post_fine_tune_request(
+        self, inputs: Sequence[str], kwargs: Mapping[str, Any]
+    ) -> Mapping[str, Any]:
+        """Build the kwargs for the POST request, used by the sync and async calls.
+
+        Args:
+            inputs (Sequence[str]): input samples used for fine-tuning
+            kwargs (dict): model kwargs in payload
+
+        Returns:
+            Dict[str, Union[str, dict]]: kwargs for the POST request
+        """
+        _model_kwargs = self.model_kwargs or {}
+        _params = {**_model_kwargs, **kwargs}
+
+        multipliers = _params.get("multipliers", None)
+
+        return dict(
+            url=f"{self.gradient_api_url}/models/{self.model_id}/fine-tune",
+            headers={
+                "authorization": f"Bearer {self.gradient_access_token}",
+                "x-gradient-workspace-id": f"{self.gradient_workspace_id}",
+                "accept": "application/json",
+                "content-type": "application/json",
+            },
+            json=dict(
+                samples=tuple(
+                    {
+                        "inputs": input,
+                    }
+                    for input in inputs
+                )
+                if multipliers is None
+                else tuple(
+                    {
+                        "inputs": input,
+                        "fineTuningParameters": {
+                            "multiplier": multiplier,
+                        },
+                    }
+                    for input, multiplier in zip(inputs, multipliers)
+                ),
+            ),
+        )
+
     def _kwargs_post_request(
         self, prompt: str, kwargs: Mapping[str, Any]
     ) -> Mapping[str, Any]:
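
To make the request shape concrete, here is a standalone sketch of the payload this helper assembles for the fine-tune endpoint; the URL and header values are invented placeholders, and only the JSON structure mirrors the code above:

```python
# Without "multipliers" in the params: one {"inputs": ...} object per sample.
plain = tuple({"inputs": text} for text in ["first sample", "second sample"])

# With "multipliers": each sample is zipped with its per-sample weight.
weighted = tuple(
    {"inputs": text, "fineTuningParameters": {"multiplier": m}}
    for text, m in zip(["first sample", "second sample"], [1.0, 0.5])
)

# The helper's return value is meant to be splatted directly into
# requests.post(**kwargs) or an aiohttp session.post(**kwargs).
kwargs = dict(
    url="https://example.invalid/models/<model_id>/fine-tune",  # placeholder
    headers={
        "authorization": "Bearer <token>",
        "x-gradient-workspace-id": "<workspace-id>",
        "accept": "application/json",
        "content-type": "application/json",
    },
    json={"samples": weighted},
)
```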
@@ -234,3 +283,60 @@ class GradientLLM(LLM):
             text = enforce_stop_tokens(text, stop)
 
         return text
+
+    def train_unsupervised(
+        self,
+        inputs: Sequence[str],
+        **kwargs: Any,
+    ) -> TrainResult:
+        try:
+            response = requests.post(
+                **self._kwargs_post_fine_tune_request(inputs, kwargs)
+            )
+            if response.status_code != 200:
+                raise Exception(
+                    f"Gradient returned an unexpected response with status "
+                    f"{response.status_code}: {response.text}"
+                )
+        except requests.exceptions.RequestException as e:
+            raise Exception(f"RequestException while calling Gradient Endpoint: {e}")
+
+        response_json = response.json()
+        loss = response_json["sumLoss"] / response_json["numberOfTrainableTokens"]
+        return TrainResult(loss=loss)
+
+    async def atrain_unsupervised(
+        self,
+        inputs: Sequence[str],
+        **kwargs: Any,
+    ) -> TrainResult:
+        if not self.aiosession:
+            async with aiohttp.ClientSession() as session:
+                async with session.post(
+                    **self._kwargs_post_fine_tune_request(inputs, kwargs)
+                ) as response:
+                    if response.status != 200:
+                        raise Exception(
+                            f"Gradient returned an unexpected response with status "
+                            f"{response.status}: {await response.text()}"
+                        )
+                    response_json = await response.json()
+                    loss = (
+                        response_json["sumLoss"]
+                        / response_json["numberOfTrainableTokens"]
+                    )
+        else:
+            async with self.aiosession.post(
+                **self._kwargs_post_fine_tune_request(inputs, kwargs)
+            ) as response:
+                if response.status != 200:
+                    raise Exception(
+                        f"Gradient returned an unexpected response with status "
+                        f"{response.status}: {await response.text()}"
+                    )
+                response_json = await response.json()
+                loss = (
+                    response_json["sumLoss"] / response_json["numberOfTrainableTokens"]
+                )
+
+        return TrainResult(loss=loss)
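
Both methods report the same metric: the endpoint's `sumLoss` divided by `numberOfTrainableTokens`, i.e. the average loss per trainable token. A sketch of the async variant, with placeholder credentials as before:

```python
import asyncio

from langchain.llms import GradientLLM

async def main() -> None:
    llm = GradientLLM(
        model_id="YOUR_ADAPTER_MODEL_ID",
        gradient_access_token="YOUR_GRADIENT_ACCESS_TOKEN",
        gradient_workspace_id="YOUR_GRADIENT_WORKSPACE_ID",
    )
    # Same request as train_unsupervised, issued through aiohttp.
    result = await llm.atrain_unsupervised(["Sample text to train on."])
    print(f"per-token loss: {result['loss']:.4f}")

asyncio.run(main())
```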
