Add DeepInfra embeddings integration with tests and examples, better exception handling for Deep Infra LLM (#5854)

#### Who can review?

Tag maintainers/contributors who might be interested:
  @hwchase17 - project lead
  - @agola11

---------

Co-authored-by: Yessen Kanapin <yessen@deepinfra.com>
searx_updates
Yessen Kanapin 12 months ago committed by GitHub
parent 4d8cda1c3b
commit c66755b661
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,133 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# DeepInfra\n",
"\n",
"[DeepInfra](https://deepinfra.com/?utm_source=langchain) is a serverless inference as a service that provides access to a [variety of LLMs](https://deepinfra.com/models?utm_source=langchain) and [embeddings models](https://deepinfra.com/models?type=embeddings&utm_source=langchain). This notebook goes over how to use LangChain with DeepInfra for text embeddings."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdin",
"output_type": "stream",
"text": [
" ········\n"
]
}
],
"source": [
"# sign up for an account: https://deepinfra.com/login?utm_source=langchain\n",
"\n",
"from getpass import getpass\n",
"\n",
"DEEPINFRA_API_TOKEN = getpass()"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"os.environ[\"DEEPINFRA_API_TOKEN\"] = DEEPINFRA_API_TOKEN"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"from langchain.embeddings import DeepInfraEmbeddings"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"embeddings = DeepInfraEmbeddings(\n",
" model_id=\"sentence-transformers/clip-ViT-B-32\",\n",
" query_instruction=\"\",\n",
" embed_instruction=\"\",\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"docs = [\"Dog is not a cat\",\n",
" \"Beta is the second letter of Greek alphabet\"]\n",
"document_result = embeddings.embed_documents(docs)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"query = \"What is the first letter of Greek alphabet\"\n",
"query_result = embeddings.embed_query(query)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Cosine similarity between \"Dog is not a cat\" and query: 0.7489097144129355\n",
"Cosine similarity between \"Beta is the second letter of Greek alphabet\" and query: 0.9519380640702013\n"
]
}
],
"source": [
"import numpy as np\n",
"\n",
"query_numpy = np.array(query_result)\n",
"for doc_res, doc in zip(document_result, docs):\n",
" document_numpy = np.array(doc_res)\n",
" similarity = np.dot(query_numpy, document_numpy) / (np.linalg.norm(query_numpy)*np.linalg.norm(document_numpy))\n",
" print(f\"Cosine similarity between \\\"{doc}\\\" and query: {similarity}\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.10"
}
},
"nbformat": 4,
"nbformat_minor": 4
}

@ -8,6 +8,7 @@ from langchain.embeddings.aleph_alpha import (
) )
from langchain.embeddings.bedrock import BedrockEmbeddings from langchain.embeddings.bedrock import BedrockEmbeddings
from langchain.embeddings.cohere import CohereEmbeddings from langchain.embeddings.cohere import CohereEmbeddings
from langchain.embeddings.deepinfra import DeepInfraEmbeddings
from langchain.embeddings.elasticsearch import ElasticsearchEmbeddings from langchain.embeddings.elasticsearch import ElasticsearchEmbeddings
from langchain.embeddings.fake import FakeEmbeddings from langchain.embeddings.fake import FakeEmbeddings
from langchain.embeddings.google_palm import GooglePalmEmbeddings from langchain.embeddings.google_palm import GooglePalmEmbeddings
@ -58,6 +59,7 @@ __all__ = [
"MiniMaxEmbeddings", "MiniMaxEmbeddings",
"VertexAIEmbeddings", "VertexAIEmbeddings",
"BedrockEmbeddings", "BedrockEmbeddings",
"DeepInfraEmbeddings",
] ]

@ -0,0 +1,129 @@
from typing import Any, Dict, List, Mapping, Optional
import requests
from pydantic import BaseModel, Extra, root_validator
from langchain.embeddings.base import Embeddings
from langchain.utils import get_from_dict_or_env
DEFAULT_MODEL_ID = "sentence-transformers/clip-ViT-B-32"
class DeepInfraEmbeddings(BaseModel, Embeddings):
"""Wrapper around Deep Infra's embedding inference service.
To use, you should have the
environment variable ``DEEPINFRA_API_TOKEN`` set with your API token, or pass
it as a named parameter to the constructor.
There are multiple embeddings models available,
see https://deepinfra.com/models?type=embeddings.
Example:
.. code-block:: python
from langchain.embeddings import DeepInfraEmbeddings
deepinfra_emb = DeepInfraEmbeddings(
model_id="sentence-transformers/clip-ViT-B-32",
deepinfra_api_token="my-api-key"
)
r1 = deepinfra_emb.embed_documents(
[
"Alpha is the first letter of Greek alphabet",
"Beta is the second letter of Greek alphabet",
]
)
r2 = deepinfra_emb.embed_query(
"What is the second letter of Greek alphabet"
)
"""
model_id: str = DEFAULT_MODEL_ID
"""Embeddings model to use."""
normalize: bool = False
"""whether to normalize the computed embeddings"""
embed_instruction: str = "passage: "
"""Instruction used to embed documents."""
query_instruction: str = "query: "
"""Instruction used to embed the query."""
model_kwargs: Optional[dict] = None
"""Other model keyword args"""
deepinfra_api_token: Optional[str] = None
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
deepinfra_api_token = get_from_dict_or_env(
values, "deepinfra_api_token", "DEEPINFRA_API_TOKEN"
)
values["deepinfra_api_token"] = deepinfra_api_token
return values
@property
def _identifying_params(self) -> Mapping[str, Any]:
"""Get the identifying parameters."""
return {"model_id": self.model_id}
def _embed(self, input: List[str]) -> List[List[float]]:
_model_kwargs = self.model_kwargs or {}
# HTTP headers for authorization
headers = {
"Authorization": f"bearer {self.deepinfra_api_token}",
"Content-Type": "application/json",
}
# send request
try:
res = requests.post(
f"https://api.deepinfra.com/v1/inference/{self.model_id}",
headers=headers,
json={"inputs": input, "normalize": self.normalize, **_model_kwargs},
)
except requests.exceptions.RequestException as e:
raise ValueError(f"Error raised by inference endpoint: {e}")
if res.status_code != 200:
raise ValueError(
"Error raised by inference API HTTP code: %s, %s"
% (res.status_code, res.text)
)
try:
t = res.json()
embeddings = t["embeddings"]
except requests.exceptions.JSONDecodeError as e:
raise ValueError(
f"Error raised by inference API: {e}.\nResponse: {res.text}"
)
return embeddings
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Embed documents using a Deep Infra deployed embedding model.
Args:
texts: The list of texts to embed.
Returns:
List of embeddings, one for each text.
"""
instruction_pairs = [f"{self.query_instruction}{text}" for text in texts]
embeddings = self._embed(instruction_pairs)
return embeddings
def embed_query(self, text: str) -> List[float]:
"""Embed a query using a Deep Infra deployed embedding model.
Args:
text: The text to embed.
Returns:
Embeddings for the text.
"""
instruction_pair = f"{self.query_instruction}{text}"
embedding = self._embed([instruction_pair])[0]
return embedding

@ -82,20 +82,33 @@ class DeepInfra(LLM):
response = di("Tell me a joke.") response = di("Tell me a joke.")
""" """
_model_kwargs = self.model_kwargs or {} _model_kwargs = self.model_kwargs or {}
# HTTP headers for authorization
headers = {
"Authorization": f"bearer {self.deepinfra_api_token}",
"Content-Type": "application/json",
}
res = requests.post( try:
f"https://api.deepinfra.com/v1/inference/{self.model_id}", res = requests.post(
headers={ f"https://api.deepinfra.com/v1/inference/{self.model_id}",
"Authorization": f"bearer {self.deepinfra_api_token}", headers=headers,
"Content-Type": "application/json", json={"input": prompt, **_model_kwargs},
}, )
json={"input": prompt, **_model_kwargs}, except requests.exceptions.RequestException as e:
) raise ValueError(f"Error raised by inference endpoint: {e}")
if res.status_code != 200: if res.status_code != 200:
raise ValueError("Error raised by inference API") raise ValueError(
t = res.json() "Error raised by inference API HTTP code: %s, %s"
text = t["results"][0]["generated_text"] % (res.status_code, res.text)
)
try:
t = res.json()
text = t["results"][0]["generated_text"]
except requests.exceptions.JSONDecodeError as e:
raise ValueError(
f"Error raised by inference API: {e}.\nResponse: {res.text}"
)
if stop is not None: if stop is not None:
# I believe this is required since the stop tokens # I believe this is required since the stop tokens

@ -0,0 +1,19 @@
"""Test DeepInfra API wrapper."""
from langchain.embeddings import DeepInfraEmbeddings
def test_deepinfra_call() -> None:
"""Test valid call to DeepInfra."""
deepinfra_emb = DeepInfraEmbeddings(model_id="sentence-transformers/clip-ViT-B-32")
r1 = deepinfra_emb.embed_documents(
[
"Alpha is the first letter of Greek alphabet",
"Beta is the second letter of Greek alphabet",
]
)
assert len(r1) == 2
assert len(r1[0]) == 512
assert len(r1[1]) == 512
r2 = deepinfra_emb.embed_query("What is the third letter of Greek alphabet")
assert len(r2) == 512

@ -0,0 +1,10 @@
"""Test DeepInfra API wrapper."""
from langchain.llms.deepinfra import DeepInfra
def test_deepinfra_call() -> None:
"""Test valid call to DeepInfra."""
llm = DeepInfra(model_id="google/flan-t5-small")
output = llm("What is 2 + 2?")
assert isinstance(output, str)
Loading…
Cancel
Save