Mirror of https://github.com/hwchase17/langchain (synced 2024-11-06 03:20:49 +00:00)
Add DeepInfra LLM support (#1232)
DeepInfra is an Inference-as-a-Service provider. This adds a simple LLM wrapper that calls DeepInfra's HTTPS inference API via the `requests` package.
This commit is contained in:
parent b7765a95a0
commit 8e3cd3e0dd
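Under the hood, the wrapper boils down to a single HTTPS POST against the DeepInfra inference endpoint; a minimal sketch with `requests`, mirroring the `_call` implementation in the diff below (the token and model id are placeholders):

```python
import requests

# Placeholders: substitute a real API token and a deployed model id.
API_TOKEN = "YOUR_KEY_HERE"
MODEL_ID = "google/flan-t5-xl"

res = requests.post(
    f"https://api.deepinfra.com/v1/inference/{MODEL_ID}",
    headers={"Authorization": f"bearer {API_TOKEN}"},
    json={"input": "Tell me a joke."},
)
res.raise_for_status()
# The API returns a list of generations; take the first one.
print(res.json()[0]["generated_text"])
```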
17 docs/ecosystem/deepinfra.md Normal file
@@ -0,0 +1,17 @@
# DeepInfra

This page covers how to use the DeepInfra ecosystem within LangChain.
It is broken into two parts: installation and setup, and then references to specific DeepInfra wrappers.

## Installation and Setup

- Get your DeepInfra API key from [deepinfra.com](https://deepinfra.com/).
- Set the API key as an environment variable (`DEEPINFRA_API_TOKEN`); see the sketch below.
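A minimal sketch of setting the token from Python (the placeholder value is hypothetical; the wrapper also accepts the token as a constructor argument):

```python
import os

# Hypothetical placeholder; substitute the key from deepinfra.com.
os.environ["DEEPINFRA_API_TOKEN"] = "YOUR_KEY_HERE"
```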
## Wrappers

### LLM

There exists a DeepInfra LLM wrapper, which you can access with

```python
from langchain.llms import DeepInfra
```
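A minimal usage sketch, assuming `DEEPINFRA_API_TOKEN` is set and a model is deployed; the model id below is the wrapper's default and stands in for your own deployment:

```python
from langchain.llms import DeepInfra

# Reads DEEPINFRA_API_TOKEN from the environment via the wrapper's validator.
llm = DeepInfra(model_id="google/flan-t5-xl")
print(llm("What is the capital of France?"))
```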
@@ -27,6 +27,8 @@ The examples here are all "how-to" guides for how to integrate with various LLM

 `Anthropic <./integrations/anthropic_example.html>`_: Covers how to use Anthropic models with Langchain.

+`DeepInfra <./integrations/deepinfra_example.html>`_: Covers how to utilize the DeepInfra wrapper.
+
 `Self-Hosted Models (via Runhouse) <./integrations/self_hosted_examples.html>`_: Covers how to run models on existing or on-demand remote compute with Langchain.
141 docs/modules/llms/integrations/deepinfra_example.ipynb Normal file
@@ -0,0 +1,141 @@
{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# DeepInfra LLM Example\n",
    "This notebook goes over how to use Langchain with [DeepInfra](https://deepinfra.com)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from langchain.llms import DeepInfra\n",
    "from langchain import PromptTemplate, LLMChain"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set the Environment API Key\n",
    "Make sure to get your API key from DeepInfra. You are given 1 hour of free serverless GPU compute to test different models.\n",
    "You can print your token with `deepctl auth token`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "os.environ[\"DEEPINFRA_API_TOKEN\"] = \"YOUR_KEY_HERE\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create the DeepInfra instance\n",
    "Make sure to deploy your model first via `deepctl deploy create -m google/flan-t5-xl` (for example)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "llm = DeepInfra(model_id=\"DEPLOYED MODEL ID\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create a Prompt Template\n",
    "We will create a prompt template for Question and Answer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "template = \"\"\"Question: {question}\n",
    "\n",
    "Answer: Let's think step by step.\"\"\"\n",
    "\n",
    "prompt = PromptTemplate(template=template, input_variables=[\"question\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Initialize the LLMChain"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "llm_chain = LLMChain(prompt=prompt, llm=llm)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run the LLMChain\n",
    "Provide a question and run the LLMChain."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "question = \"What NFL team won the Super Bowl in 2015?\"\n",
    "\n",
    "llm_chain.run(question)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.9.12 ('palm')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.9.12"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "a0a0263b650d907a3bfe41c0f8d6a63a071b884df3cfdc1579f00cdc1aed6b03"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
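For readers who prefer a plain script, the notebook's steps collapse to the sketch below; the token and model id are placeholders you would replace with your own:

```python
import os

from langchain import LLMChain, PromptTemplate
from langchain.llms import DeepInfra

# Placeholder token; substitute your DeepInfra API key.
os.environ["DEEPINFRA_API_TOKEN"] = "YOUR_KEY_HERE"

# Substitute the id of a model you have deployed via `deepctl deploy create`.
llm = DeepInfra(model_id="google/flan-t5-xl")

template = """Question: {question}

Answer: Let's think step by step."""
prompt = PromptTemplate(template=template, input_variables=["question"])

llm_chain = LLMChain(prompt=prompt, llm=llm)
print(llm_chain.run("What NFL team won the Super Bowl in 2015?"))
```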
@@ -7,6 +7,7 @@ from langchain.llms.anthropic import Anthropic
 from langchain.llms.base import BaseLLM
 from langchain.llms.cerebriumai import CerebriumAI
 from langchain.llms.cohere import Cohere
+from langchain.llms.deepinfra import DeepInfra
 from langchain.llms.forefrontai import ForefrontAI
 from langchain.llms.gooseai import GooseAI
 from langchain.llms.huggingface_endpoint import HuggingFaceEndpoint
@@ -24,6 +25,7 @@ __all__ = [
     "AlephAlpha",
     "CerebriumAI",
     "Cohere",
+    "DeepInfra",
     "ForefrontAI",
     "GooseAI",
     "NLPCloud",
@@ -45,6 +47,7 @@ type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
     "anthropic": Anthropic,
     "cerebriumai": CerebriumAI,
     "cohere": Cohere,
+    "deepinfra": DeepInfra,
     "forefrontai": ForefrontAI,
     "gooseai": GooseAI,
     "huggingface_hub": HuggingFaceHub,
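These hunks wire DeepInfra into `langchain/llms/__init__.py`: the import, the public `__all__` export, and the `type_to_cls_dict` registry that maps a string key to the class. A sketch of a registry lookup, assuming `type_to_cls_dict` is importable from `langchain.llms` (it is defined at module level there):

```python
from langchain.llms import type_to_cls_dict

# Resolve the string key registered above to the wrapper class.
llm_cls = type_to_cls_dict["deepinfra"]
llm = llm_cls(model_id="google/flan-t5-xl")
print(llm._llm_type)  # "deepinfra"
```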
97 langchain/llms/deepinfra.py Normal file
@@ -0,0 +1,97 @@
"""Wrapper around DeepInfra APIs."""
from typing import Any, Dict, List, Mapping, Optional

import requests
from pydantic import BaseModel, Extra, root_validator

from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from langchain.utils import get_from_dict_or_env

DEFAULT_MODEL_ID = "google/flan-t5-xl"


class DeepInfra(LLM, BaseModel):
    """Wrapper around DeepInfra deployed models.

    To use, you should have the ``requests`` python package installed, and the
    environment variable ``DEEPINFRA_API_TOKEN`` set with your API token, or pass
    it as a named parameter to the constructor.

    Only supports `text-generation` and `text2text-generation` for now.

    Example:
        .. code-block:: python

            from langchain import DeepInfra
            di = DeepInfra(model_id="google/flan-t5-xl",
                           deepinfra_api_token="my-api-key")
    """

    model_id: str = DEFAULT_MODEL_ID
    model_kwargs: Optional[dict] = None

    deepinfra_api_token: Optional[str] = None

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment."""
        deepinfra_api_token = get_from_dict_or_env(
            values, "deepinfra_api_token", "DEEPINFRA_API_TOKEN"
        )
        values["deepinfra_api_token"] = deepinfra_api_token
        return values

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {
            **{"model_id": self.model_id},
            **{"model_kwargs": self.model_kwargs},
        }

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "deepinfra"

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Call out to DeepInfra's inference API endpoint.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The string generated by the model.

        Example:
            .. code-block:: python

                response = di("Tell me a joke.")
        """
        _model_kwargs = self.model_kwargs or {}

        res = requests.post(
            f"https://api.deepinfra.com/v1/inference/{self.model_id}",
            headers={
                "Authorization": f"bearer {self.deepinfra_api_token}",
                "Content-Type": "application/json",
            },
            json={"input": prompt, **_model_kwargs},
        )

        if res.status_code != 200:
            raise ValueError("Error raised by inference API")
        text = res.json()[0]["generated_text"]

        if stop is not None:
            # I believe this is required since the stop tokens
            # are not enforced by the model parameters
            text = enforce_stop_tokens(text, stop)
        return text
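A usage sketch exercising the two knobs the wrapper exposes: `model_kwargs`, whose entries are forwarded verbatim in the request body, and `stop`, which is enforced client-side after the response returns. `max_new_tokens` is an assumed model parameter, not something the wrapper itself defines:

```python
from langchain.llms import DeepInfra

llm = DeepInfra(
    model_id="google/flan-t5-xl",
    deepinfra_api_token="my-api-key",  # or rely on DEEPINFRA_API_TOKEN
    model_kwargs={"max_new_tokens": 64},  # assumed model parameter
)

# Stop sequences are trimmed client-side with enforce_stop_tokens,
# since the API does not enforce them server-side.
text = llm("List three colors:", stop=["\n\n"])
```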