Improvements to llm/deepinfra (#10846)

- replace the `requests` package with `langchain.utilities.requests`
- add `_acall` support
- add `_stream` and `_astream` (see the usage sketch below)
- freshen up the documentation a bit
- update vendor doc
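
A minimal sketch of the new surface this PR adds, assuming `DEEPINFRA_API_TOKEN` is set in the environment; the model id and prompt are the ones used in the updated notebook:

```python
import asyncio

from langchain.llms import DeepInfra

llm = DeepInfra(model_id="meta-llama/Llama-2-70b-chat-hf")

# synchronous streaming, backed by the new _stream implementation
for chunk in llm.stream("Who let the dogs out?"):
    print(chunk, end="")


async def main() -> None:
    # async one-shot call (new _acall) and async streaming (new _astream)
    print(await llm.apredict("Who let the dogs out?"))
    async for chunk in llm.astream("Who let the dogs out?"):
        print(chunk, end="")


asyncio.run(main())
```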

@@ -6,30 +6,7 @@
"source": [
"# DeepInfra\n",
"\n",
"`DeepInfra` provides [several LLMs](https://deepinfra.com/models).\n",
"\n",
"This notebook goes over how to use Langchain with [DeepInfra](https://deepinfra.com)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Imports"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import os\n",
"from langchain.llms import DeepInfra\n",
"from langchain.prompts import PromptTemplate\n",
"from langchain.chains import LLMChain"
"[DeepInfra](https://deepinfra.com/?utm_source=langchain) is a serverless inference as a service that provides access to a [variety of LLMs](https://deepinfra.com/models?utm_source=langchain) and [embeddings models](https://deepinfra.com/models?type=embeddings&utm_source=langchain). This notebook goes over how to use LangChain with DeepInfra for language models."
]
},
{
@@ -45,7 +22,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 6,
"metadata": {
"tags": []
},
@@ -68,12 +45,14 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 7,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import os\n",
"\n",
"os.environ[\"DEEPINFRA_API_TOKEN\"] = DEEPINFRA_API_TOKEN"
]
},
@@ -87,11 +66,13 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"llm = DeepInfra(model_id=\"databricks/dolly-v2-12b\")\n",
"from langchain.llms import DeepInfra\n",
"\n",
"llm = DeepInfra(model_id=\"meta-llama/Llama-2-70b-chat-hf\")\n",
"llm.model_kwargs = {\n",
" \"temperature\": 0.7,\n",
" \"repetition_penalty\": 1.2,\n",
@@ -100,6 +81,51 @@
"}"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'This is a question that has puzzled many people'"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# run inferences directly via wrapper\n",
"llm(\"Who let the dogs out?\")"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" Will\n",
" Smith\n",
"."
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# run streaming inference\n",
"for chunk in llm.stream(\"Who let the dogs out?\"):\n",
" print(chunk)"
]
},
{
"cell_type": "markdown",
"metadata": {},
@@ -110,10 +136,12 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"from langchain.prompts import PromptTemplate\n",
"\n",
"template = \"\"\"Question: {question}\n",
"\n",
"Answer: Let's think step by step.\"\"\"\n",
@@ -130,10 +158,12 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"from langchain.chains import LLMChain\n",
"\n",
"llm_chain = LLMChain(prompt=prompt, llm=llm)"
]
},
@@ -147,16 +177,16 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"\"Penguins live in the Southern hemisphere.\\nThe North pole is located in the Northern hemisphere.\\nSo, first you need to turn the penguin South.\\nThen, support the penguin on a rotation machine,\\nmake it spin around its vertical axis,\\nand finally drop the penguin in North hemisphere.\\nNow, you have a penguin in the north pole!\\n\\nStill didn't understand?\\nWell, you're a failure as a teacher.\""
"\"Penguins are found in Antarctica and the surrounding islands, which are located at the southernmost tip of the planet. The North Pole is located at the northernmost tip of the planet, and it would be a long journey for penguins to get there. In fact, penguins don't have the ability to fly or migrate over such long distances. So, no, penguins cannot reach the North Pole. \""
]
},
"execution_count": 8,
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
@@ -166,6 +196,13 @@
"\n",
"llm_chain.run(question)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
@@ -184,7 +221,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
"version": "3.11.5"
},
"vscode": {
"interpreter": {

@@ -10,16 +10,27 @@ It is broken into two parts: installation and setup, and then references to spec
## Available Models
DeepInfra provides a range of Open Source LLMs ready for deployment.
You can list supported models [here](https://deepinfra.com/models?type=text-generation).
You can list supported models for
[text-generation](https://deepinfra.com/models?type=text-generation) and
[embeddings](https://deepinfra.com/models?type=embeddings).
google/flan\* models can be viewed [here](https://deepinfra.com/models?type=text2text-generation).
You can view a list of request and response parameters [here](https://deepinfra.com/databricks/dolly-v2-12b#API)
You can view a [list of request and response parameters](https://deepinfra.com/meta-llama/Llama-2-70b-chat-hf/api).
## Wrappers
### LLM
There exists a DeepInfra LLM wrapper, which you can access with
```python
from langchain.llms import DeepInfra
```
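
For example, a minimal call might look like this (a sketch, not official docs; the model id mirrors the one used elsewhere in this PR, and `DEEPINFRA_API_TOKEN` must be set):

```python
import os

from langchain.llms import DeepInfra

os.environ["DEEPINFRA_API_TOKEN"] = "<your token>"

llm = DeepInfra(model_id="meta-llama/Llama-2-70b-chat-hf")
print(llm("Who let the dogs out?"))
```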
### Embeddings
There is also a DeepInfra Embeddings wrapper, which you can access with
```python
from langchain.embeddings import DeepInfraEmbeddings
```
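
And a matching sketch for embeddings (`embed_query` and `embed_documents` are the standard LangChain `Embeddings` interface methods; the wrapper's default model is used here):

```python
from langchain.embeddings import DeepInfraEmbeddings

embeddings = DeepInfraEmbeddings()
query_vector = embeddings.embed_query("What is DeepInfra?")
doc_vectors = embeddings.embed_documents(["First document", "Second document"])
```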

@@ -1,11 +1,15 @@
from typing import Any, Dict, List, Mapping, Optional
import json
from typing import Any, AsyncIterator, Dict, Iterator, List, Mapping, Optional
import requests
import aiohttp
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.llms.base import LLM, GenerationChunk
from langchain.pydantic_v1 import Extra, root_validator
from langchain.utilities.requests import Requests
from langchain.utils import get_from_dict_or_env
DEFAULT_MODEL_ID = "google/flan-t5-xl"
@@ -14,9 +18,9 @@ DEFAULT_MODEL_ID = "google/flan-t5-xl"
class DeepInfra(LLM):
"""DeepInfra models.
To use, you should have the ``requests`` python package installed, and the
environment variable ``DEEPINFRA_API_TOKEN`` set with your API token, or pass
it as a named parameter to the constructor.
To use, you should have the environment variable ``DEEPINFRA_API_TOKEN``
set with your API token, or pass it as a named parameter to the
constructor.
Only supports `text-generation` and `text2text-generation` for now.
@@ -29,7 +33,7 @@ class DeepInfra(LLM):
"""
model_id: str = DEFAULT_MODEL_ID
model_kwargs: Optional[dict] = None
model_kwargs: Optional[Dict] = None
deepinfra_api_token: Optional[str] = None
@@ -60,6 +64,35 @@ class DeepInfra(LLM):
"""Return type of llm."""
return "deepinfra"
def _url(self) -> str:
return f"https://api.deepinfra.com/v1/inference/{self.model_id}"
def _headers(self) -> Dict:
return {
"Authorization": f"bearer {self.deepinfra_api_token}",
"Content-Type": "application/json",
}
def _body(self, prompt: str, kwargs: Any) -> Dict:
model_kwargs = self.model_kwargs or {}
model_kwargs = {**model_kwargs, **kwargs}
return {
"input": prompt,
**model_kwargs,
}
def _handle_status(self, code: int, text: Any) -> None:
if code >= 500:
raise Exception(f"DeepInfra Server: Error {code}")
elif code >= 400:
raise ValueError(f"DeepInfra received an invalid payload: {text}")
elif code != 200:
raise Exception(
f"DeepInfra returned an unexpected response with status "
f"{code}: {text}"
)
def _call(
self,
prompt: str,
@@ -81,38 +114,105 @@ class DeepInfra(LLM):
response = di("Tell me a joke.")
"""
_model_kwargs = self.model_kwargs or {}
_model_kwargs = {**_model_kwargs, **kwargs}
# HTTP headers for authorization
headers = {
"Authorization": f"bearer {self.deepinfra_api_token}",
"Content-Type": "application/json",
}
try:
res = requests.post(
f"https://api.deepinfra.com/v1/inference/{self.model_id}",
headers=headers,
json={"input": prompt, **_model_kwargs},
)
except requests.exceptions.RequestException as e:
raise ValueError(f"Error raised by inference endpoint: {e}")
request = Requests(headers=self._headers())
response = request.post(url=self._url(), data=self._body(prompt, kwargs))
if res.status_code != 200:
raise ValueError(
"Error raised by inference API HTTP code: %s, %s"
% (res.status_code, res.text)
)
try:
t = res.json()
text = t["results"][0]["generated_text"]
except requests.exceptions.JSONDecodeError as e:
raise ValueError(
f"Error raised by inference API: {e}.\nResponse: {res.text}"
)
self._handle_status(response.status_code, response.text)
data = response.json()
return data["results"][0]["generated_text"]
if stop is not None:
# I believe this is required since the stop tokens
# are not enforced by the model parameters
text = enforce_stop_tokens(text, stop)
return text
async def _acall(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
request = Requests(headers=self._headers())
async with request.apost(
url=self._url(), data=self._body(prompt, kwargs)
) as response:
if response.status != 200:
    # aiohttp's response.text is a coroutine method, so it must be awaited
    self._handle_status(response.status, await response.text())
data = await response.json()
return data["results"][0]["generated_text"]
def _stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
request = Requests(headers=self._headers())
response = request.post(
url=self._url(), data=self._body(prompt, {**kwargs, "stream": True})
)
self._handle_status(response.status_code, response.text)
for line in _parse_stream(response.iter_lines()):
chunk = _handle_sse_line(line)
if chunk:
yield chunk
if run_manager:
run_manager.on_llm_new_token(chunk.text)
async def _astream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[GenerationChunk]:
request = Requests(headers=self._headers())
async with request.apost(
url=self._url(), data=self._body(prompt, {**kwargs, "stream": True})
) as response:
if response.status != 200:
    # only read the body on error; awaiting response.text() on success
    # would consume the stream before it is iterated below
    self._handle_status(response.status, await response.text())
async for line in _parse_stream_async(response.content):
chunk = _handle_sse_line(line)
if chunk:
yield chunk
if run_manager:
await run_manager.on_llm_new_token(chunk.text)
def _parse_stream(rbody: Iterator[bytes]) -> Iterator[str]:
for line in rbody:
_line = _parse_stream_helper(line)
if _line is not None:
yield _line
async def _parse_stream_async(rbody: aiohttp.StreamReader) -> AsyncIterator[str]:
async for line in rbody:
_line = _parse_stream_helper(line)
if _line is not None:
yield _line
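# Illustrative SSE lines, inferred from the parsing logic below rather than
# from DeepInfra documentation:
#   b'data: {"token": {"text": " Hello"}}'  -> decoded into a GenerationChunk
#   b'data: [DONE]'                         -> end-of-stream sentinel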
def _parse_stream_helper(line: bytes) -> Optional[str]:
if line and line.startswith(b"data:"):
if line.startswith(b"data: "):
# an SSE event may be valid when it contains whitespace
line = line[len(b"data: ") :]
else:
line = line[len(b"data:") :]
if line.strip() == b"[DONE]":
# returning here will cause a GeneratorExit exception in urllib3
# and close the http connection with a TCP reset
return None
else:
return line.decode("utf-8")
return None
def _handle_sse_line(line: str) -> Optional[GenerationChunk]:
try:
obj = json.loads(line)
return GenerationChunk(
text=obj.get("token", {}).get("text"),
)
except Exception:
return None

@@ -1,10 +1,36 @@
"""Test DeepInfra API wrapper."""
import pytest
from langchain.llms.deepinfra import DeepInfra
def test_deepinfra_call() -> None:
"""Test valid call to DeepInfra."""
llm = DeepInfra(model_id="google/flan-t5-small")
llm = DeepInfra(model_id="meta-llama/Llama-2-7b-chat-hf")
output = llm("What is 2 + 2?")
assert isinstance(output, str)
@pytest.mark.asyncio
async def test_deepinfra_acall() -> None:
llm = DeepInfra(model_id="meta-llama/Llama-2-7b-chat-hf")
output = await llm.apredict("What is 2 + 2?")
assert llm._llm_type == "deepinfra"
assert isinstance(output, str)
def test_deepinfra_stream() -> None:
llm = DeepInfra(model_id="meta-llama/Llama-2-7b-chat-hf")
num_chunks = 0
for chunk in llm.stream("[INST] Hello [/INST] "):
num_chunks += 1
assert num_chunks > 0
@pytest.mark.asyncio
async def test_deepinfra_astream() -> None:
llm = DeepInfra(model_id="meta-llama/Llama-2-7b-chat-hf")
num_chunks = 0
async for chunk in llm.astream("[INST] Hello [/INST] "):
num_chunks += 1
assert num_chunks > 0
