community: extend Predibase integration to support fine-tuned LLM adapters (#19979)

- [x] **PR title**: "package: description"
- Where "package" is whichever of langchain, community, core,
experimental, etc. is being modified. Use "docs: ..." for purely docs
changes, "templates: ..." for template changes, "infra: ..." for CI
changes.
  - Example: "community: add foobar LLM"


- [x] **PR message**: ***Delete this entire checklist*** and replace
with
- **Description:** The LangChain-Predibase integration was failing because
it was not current with the Predibase SDK; in addition, the Predibase
integration tests were instantiating the LangChain Community `Predibase`
class with one required argument (`model`) missing. This change updates
the Predibase SDK usage, adds support for fine-tuned LLM adapters via the
optional `adapter_id` argument, and fixes the integration tests (a brief
usage sketch follows the checklist below).
    - **Twitter handle:** `@alexsherstinsky`


- [x] **Lint and test**: Run `make format`, `make lint` and `make test`
from the root of the package(s) you've modified. See contribution
guidelines for more: https://python.langchain.com/docs/contributing/

Additional guidelines:
- Make sure optional dependencies are imported within a function.
- Please do not add dependencies to pyproject.toml files (even optional
ones) unless they are required for unit tests.
- Most PRs should not touch more than one package.
- Changes should be backwards compatible.
- If you are adding something to community, do not re-import it in
langchain.

If no one reviews your PR within a few days, please @-mention one of
baskaryan, efriis, eyurtsev, hwchase17.
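
For reviewers, here is a minimal usage sketch of the change (it mirrors the
updated notebook/docs in this PR; the `PREDIBASE_API_TOKEN` environment
variable and the `predibase/e2e_nlg` adapter are the placeholders used in
those docs and require a real Predibase account to run):

```python
import os

from langchain_community.llms import Predibase

# Base "serverless" model only -- this is the pre-existing behavior.
model = Predibase(
    model="mistral-7b",
    predibase_api_key=os.environ.get("PREDIBASE_API_TOKEN"),
)

# New in this PR: optionally attach a fine-tuned adapter via `adapter_id`.
# The adapter must have been fine-tuned on the base model named by `model`;
# otherwise Predibase raises an error.
model_with_adapter = Predibase(
    model="mistral-7b",
    adapter_id="predibase/e2e_nlg",
    predibase_api_key=os.environ.get("PREDIBASE_API_TOKEN"),
)

print(model_with_adapter("Can you recommend me a nice dry wine?"))
```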

---------

Co-authored-by: Erick Friis <erick@langchain.dev>

@@ -50,7 +50,24 @@
"from langchain_community.llms import Predibase\n",
"\n",
"model = Predibase(\n",
" model=\"vicuna-13b\", predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\")\n",
" model=\"mistral-7b\",\n",
" predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\"),\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain_community.llms import Predibase\n",
"\n",
"# With an adapter, fine-tuned on the specified model\n",
"model = Predibase(\n",
" model=\"mistral-7b\",\n",
" adapter_id=\"predibase/e2e_nlg\",\n",
" predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\"),\n",
")"
]
},
@@ -66,19 +83,43 @@
},
{
"cell_type": "markdown",
"metadata": {},
"metadata": {
"vscode": {
"languageId": "plaintext"
}
},
"source": [
"## Chain Call Setup"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"vscode": {
"languageId": "plaintext"
}
},
"outputs": [],
"source": [
"from langchain_community.llms import Predibase\n",
"\n",
"model = Predibase(\n",
" model=\"mistral-7b\", predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\")\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# With an adapter, fine-tuned on the specified model\n",
"llm = Predibase(\n",
" model=\"vicuna-13b\", predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\")\n",
" model=\"mistral-7b\",\n",
" adapter_id=\"predibase/e2e_nlg\",\n",
" predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\"),\n",
")"
]
},
@@ -169,7 +210,11 @@
"from langchain_community.llms import Predibase\n",
"\n",
"model = Predibase(\n",
" model=\"my-finetuned-LLM\", predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\")\n",
" model=\"my-base-LLM\",\n",
" adapter_id=\"my-finetuned-adapter-id\",\n",
" predibase_api_key=os.environ.get(\n",
" \"PREDIBASE_API_TOKEN\"\n",
" ), # Adapter argument is optional.\n",
")\n",
"# replace my-finetuned-LLM with the name of your model in Predibase"
]

@@ -17,7 +17,21 @@ os.environ["PREDIBASE_API_TOKEN"] = "{PREDIBASE_API_TOKEN}"
from langchain_community.llms import Predibase
model = Predibase(model = 'vicuna-13b', predibase_api_key=os.environ.get('PREDIBASE_API_TOKEN'))
model = Predibase(model="mistral-7b"", predibase_api_key=os.environ.get("PREDIBASE_API_TOKEN"))
response = model("Can you recommend me a nice dry wine?")
print(response)
```
Predibase also supports adapters that are fine-tuned on the base model given by the `model` argument:
```python
import os
os.environ["PREDIBASE_API_TOKEN"] = "{PREDIBASE_API_TOKEN}"
from langchain_community.llms import Predibase
model = Predibase(model="mistral-7b"", adapter_id="predibase/e2e_nlg", predibase_api_key=os.environ.get("PREDIBASE_API_TOKEN"))
response = model("Can you recommend me a nice dry wine?")
print(response)

@@ -10,10 +10,18 @@ class Predibase(LLM):
    To use, you should have the ``predibase`` python package installed,
    and have your Predibase API key.
    The `model` parameter is the Predibase "serverless" base_model ID
    (see https://docs.predibase.com/user-guide/inference/models for the catalog).
    An optional `adapter_id` parameter is the HuggingFace ID of a fine-tuned LLM
    adapter, whose base model is the `model` parameter; the fine-tuned adapter
    must be compatible with its base model; otherwise, an error is raised.
    """

    model: str
    predibase_api_key: SecretStr
    adapter_id: Optional[str] = None
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    default_options_for_generation: dict = Field(
        {
@@ -38,7 +46,10 @@ class Predibase(LLM):
            from predibase import PredibaseClient
            from predibase.pql import get_session
            from predibase.pql.api import Session
            from predibase.resource.llm.interface import LLMDeployment
            from predibase.resource.llm.interface import (
                HuggingFaceLLM,
                LLMDeployment,
            )
            from predibase.resource.llm.response import GeneratedResponse

            session: Session = get_session(
@@ -55,15 +66,23 @@ class Predibase(LLM):
        except ValueError as e:
            raise ValueError("Your API key is not correct. Please try again") from e
        options: Dict[str, Union[str, float]] = (
            kwargs or self.default_options_for_generation
            self.model_kwargs or self.default_options_for_generation
        )
        base_llm_deployment: LLMDeployment = pc.LLM(
            uri=f"pb://deployments/{self.model}"
        )
        result: GeneratedResponse = base_llm_deployment.generate(
            prompt=prompt,
            options=options,
        )
        result: GeneratedResponse
        if self.adapter_id:
            adapter_model: HuggingFaceLLM = pc.LLM(uri=f"hf://{self.adapter_id}")
            result = base_llm_deployment.with_adapter(model=adapter_model).generate(
                prompt=prompt,
                options=options,
            )
        else:
            result = base_llm_deployment.generate(
                prompt=prompt,
                options=options,
            )
        return result.response

    @property

@@ -17,3 +17,20 @@ def test_api_key_masked_when_passed_via_constructor(
    captured = capsys.readouterr()
    assert captured.out == "**********"


def test_specifying_adapter_id_argument() -> None:
    llm = Predibase(model="my_llm", predibase_api_key="secret-api-key")
    assert not llm.adapter_id

    llm = Predibase(
        model="my_llm", predibase_api_key="secret-api-key", adapter_id="my-hf-adapter"
    )
    assert llm.adapter_id == "my-hf-adapter"

    llm = Predibase(
        model="my_llm",
        adapter_id="my-other-hf-adapter",
        predibase_api_key="secret-api-key",
    )
    assert llm.adapter_id == "my-other-hf-adapter"