From c66755b6613567ebfe957e45cf710fc36d89684b Mon Sep 17 00:00:00 2001 From: Yessen Kanapin Date: Wed, 7 Jun 2023 19:14:30 -0700 Subject: [PATCH] Add DeepInfra embeddings integration with tests and examples, better exception handling for Deep Infra LLM (#5854) #### Who can review? Tag maintainers/contributors who might be interested: @hwchase17 - project lead - @agola11 --------- Co-authored-by: Yessen Kanapin --- .../text_embedding/examples/deepinfra.ipynb | 133 ++++++++++++++++++ langchain/embeddings/__init__.py | 2 + langchain/embeddings/deepinfra.py | 129 +++++++++++++++++ langchain/llms/deepinfra.py | 35 +++-- .../embeddings/test_deepinfra.py | 19 +++ .../integration_tests/llms/test_deepinfra.py | 10 ++ 6 files changed, 317 insertions(+), 11 deletions(-) create mode 100644 docs/modules/models/text_embedding/examples/deepinfra.ipynb create mode 100644 langchain/embeddings/deepinfra.py create mode 100644 tests/integration_tests/embeddings/test_deepinfra.py create mode 100644 tests/integration_tests/llms/test_deepinfra.py diff --git a/docs/modules/models/text_embedding/examples/deepinfra.ipynb b/docs/modules/models/text_embedding/examples/deepinfra.ipynb new file mode 100644 index 00000000..25b1972f --- /dev/null +++ b/docs/modules/models/text_embedding/examples/deepinfra.ipynb @@ -0,0 +1,133 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# DeepInfra\n", + "\n", + "[DeepInfra](https://deepinfra.com/?utm_source=langchain) is a serverless inference as a service that provides access to a [variety of LLMs](https://deepinfra.com/models?utm_source=langchain) and [embeddings models](https://deepinfra.com/models?type=embeddings&utm_source=langchain). This notebook goes over how to use LangChain with DeepInfra for text embeddings." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdin", + "output_type": "stream", + "text": [ + " ········\n" + ] + } + ], + "source": [ + "# sign up for an account: https://deepinfra.com/login?utm_source=langchain\n", + "\n", + "from getpass import getpass\n", + "\n", + "DEEPINFRA_API_TOKEN = getpass()" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "os.environ[\"DEEPINFRA_API_TOKEN\"] = DEEPINFRA_API_TOKEN" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.embeddings import DeepInfraEmbeddings" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "embeddings = DeepInfraEmbeddings(\n", + " model_id=\"sentence-transformers/clip-ViT-B-32\",\n", + " query_instruction=\"\",\n", + " embed_instruction=\"\",\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "docs = [\"Dog is not a cat\",\n", + " \"Beta is the second letter of Greek alphabet\"]\n", + "document_result = embeddings.embed_documents(docs)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "query = \"What is the first letter of Greek alphabet\"\n", + "query_result = embeddings.embed_query(query)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Cosine similarity between \"Dog is not a cat\" and query: 0.7489097144129355\n", + "Cosine similarity between \"Beta is the second letter of Greek alphabet\" and query: 0.9519380640702013\n" + ] + } + ], + "source": [ + "import numpy as np\n", + "\n", + "query_numpy = np.array(query_result)\n", + "for doc_res, doc in zip(document_result, docs):\n", + " document_numpy = np.array(doc_res)\n", + " similarity = np.dot(query_numpy, document_numpy) / (np.linalg.norm(query_numpy)*np.linalg.norm(document_numpy))\n", + " print(f\"Cosine similarity between \\\"{doc}\\\" and query: {similarity}\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.10" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/langchain/embeddings/__init__.py b/langchain/embeddings/__init__.py index 6261b3dc..dafc3e64 100644 --- a/langchain/embeddings/__init__.py +++ b/langchain/embeddings/__init__.py @@ -8,6 +8,7 @@ from langchain.embeddings.aleph_alpha import ( ) from langchain.embeddings.bedrock import BedrockEmbeddings from langchain.embeddings.cohere import CohereEmbeddings +from langchain.embeddings.deepinfra import DeepInfraEmbeddings from langchain.embeddings.elasticsearch import ElasticsearchEmbeddings from langchain.embeddings.fake import FakeEmbeddings from langchain.embeddings.google_palm import GooglePalmEmbeddings @@ -58,6 +59,7 @@ __all__ = [ "MiniMaxEmbeddings", "VertexAIEmbeddings", "BedrockEmbeddings", + "DeepInfraEmbeddings", ] diff --git a/langchain/embeddings/deepinfra.py b/langchain/embeddings/deepinfra.py new file mode 100644 index 00000000..9c66dda0 --- /dev/null +++ b/langchain/embeddings/deepinfra.py @@ -0,0 +1,129 @@ +from typing import Any, Dict, List, Mapping, Optional + +import requests +from pydantic import BaseModel, Extra, root_validator + +from langchain.embeddings.base import Embeddings +from langchain.utils import get_from_dict_or_env + +DEFAULT_MODEL_ID = "sentence-transformers/clip-ViT-B-32" + + +class DeepInfraEmbeddings(BaseModel, Embeddings): + """Wrapper around Deep Infra's embedding inference service. + + To use, you should have the + environment variable ``DEEPINFRA_API_TOKEN`` set with your API token, or pass + it as a named parameter to the constructor. + There are multiple embeddings models available, + see https://deepinfra.com/models?type=embeddings. + + Example: + .. code-block:: python + + from langchain.embeddings import DeepInfraEmbeddings + deepinfra_emb = DeepInfraEmbeddings( + model_id="sentence-transformers/clip-ViT-B-32", + deepinfra_api_token="my-api-key" + ) + r1 = deepinfra_emb.embed_documents( + [ + "Alpha is the first letter of Greek alphabet", + "Beta is the second letter of Greek alphabet", + ] + ) + r2 = deepinfra_emb.embed_query( + "What is the second letter of Greek alphabet" + ) + + """ + + model_id: str = DEFAULT_MODEL_ID + """Embeddings model to use.""" + normalize: bool = False + """whether to normalize the computed embeddings""" + embed_instruction: str = "passage: " + """Instruction used to embed documents.""" + query_instruction: str = "query: " + """Instruction used to embed the query.""" + model_kwargs: Optional[dict] = None + """Other model keyword args""" + + deepinfra_api_token: Optional[str] = None + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + + @root_validator() + def validate_environment(cls, values: Dict) -> Dict: + """Validate that api key and python package exists in environment.""" + deepinfra_api_token = get_from_dict_or_env( + values, "deepinfra_api_token", "DEEPINFRA_API_TOKEN" + ) + values["deepinfra_api_token"] = deepinfra_api_token + return values + + @property + def _identifying_params(self) -> Mapping[str, Any]: + """Get the identifying parameters.""" + return {"model_id": self.model_id} + + def _embed(self, input: List[str]) -> List[List[float]]: + _model_kwargs = self.model_kwargs or {} + # HTTP headers for authorization + headers = { + "Authorization": f"bearer {self.deepinfra_api_token}", + "Content-Type": "application/json", + } + # send request + try: + res = requests.post( + f"https://api.deepinfra.com/v1/inference/{self.model_id}", + headers=headers, + json={"inputs": input, "normalize": self.normalize, **_model_kwargs}, + ) + except requests.exceptions.RequestException as e: + raise ValueError(f"Error raised by inference endpoint: {e}") + + if res.status_code != 200: + raise ValueError( + "Error raised by inference API HTTP code: %s, %s" + % (res.status_code, res.text) + ) + try: + t = res.json() + embeddings = t["embeddings"] + except requests.exceptions.JSONDecodeError as e: + raise ValueError( + f"Error raised by inference API: {e}.\nResponse: {res.text}" + ) + + return embeddings + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + """Embed documents using a Deep Infra deployed embedding model. + + Args: + texts: The list of texts to embed. + + Returns: + List of embeddings, one for each text. + """ + instruction_pairs = [f"{self.query_instruction}{text}" for text in texts] + embeddings = self._embed(instruction_pairs) + return embeddings + + def embed_query(self, text: str) -> List[float]: + """Embed a query using a Deep Infra deployed embedding model. + + Args: + text: The text to embed. + + Returns: + Embeddings for the text. + """ + instruction_pair = f"{self.query_instruction}{text}" + embedding = self._embed([instruction_pair])[0] + return embedding diff --git a/langchain/llms/deepinfra.py b/langchain/llms/deepinfra.py index 0341c86d..6e18f2e2 100644 --- a/langchain/llms/deepinfra.py +++ b/langchain/llms/deepinfra.py @@ -82,20 +82,33 @@ class DeepInfra(LLM): response = di("Tell me a joke.") """ _model_kwargs = self.model_kwargs or {} + # HTTP headers for authorization + headers = { + "Authorization": f"bearer {self.deepinfra_api_token}", + "Content-Type": "application/json", + } - res = requests.post( - f"https://api.deepinfra.com/v1/inference/{self.model_id}", - headers={ - "Authorization": f"bearer {self.deepinfra_api_token}", - "Content-Type": "application/json", - }, - json={"input": prompt, **_model_kwargs}, - ) + try: + res = requests.post( + f"https://api.deepinfra.com/v1/inference/{self.model_id}", + headers=headers, + json={"input": prompt, **_model_kwargs}, + ) + except requests.exceptions.RequestException as e: + raise ValueError(f"Error raised by inference endpoint: {e}") if res.status_code != 200: - raise ValueError("Error raised by inference API") - t = res.json() - text = t["results"][0]["generated_text"] + raise ValueError( + "Error raised by inference API HTTP code: %s, %s" + % (res.status_code, res.text) + ) + try: + t = res.json() + text = t["results"][0]["generated_text"] + except requests.exceptions.JSONDecodeError as e: + raise ValueError( + f"Error raised by inference API: {e}.\nResponse: {res.text}" + ) if stop is not None: # I believe this is required since the stop tokens diff --git a/tests/integration_tests/embeddings/test_deepinfra.py b/tests/integration_tests/embeddings/test_deepinfra.py new file mode 100644 index 00000000..17099615 --- /dev/null +++ b/tests/integration_tests/embeddings/test_deepinfra.py @@ -0,0 +1,19 @@ +"""Test DeepInfra API wrapper.""" + +from langchain.embeddings import DeepInfraEmbeddings + + +def test_deepinfra_call() -> None: + """Test valid call to DeepInfra.""" + deepinfra_emb = DeepInfraEmbeddings(model_id="sentence-transformers/clip-ViT-B-32") + r1 = deepinfra_emb.embed_documents( + [ + "Alpha is the first letter of Greek alphabet", + "Beta is the second letter of Greek alphabet", + ] + ) + assert len(r1) == 2 + assert len(r1[0]) == 512 + assert len(r1[1]) == 512 + r2 = deepinfra_emb.embed_query("What is the third letter of Greek alphabet") + assert len(r2) == 512 diff --git a/tests/integration_tests/llms/test_deepinfra.py b/tests/integration_tests/llms/test_deepinfra.py new file mode 100644 index 00000000..502a9bbb --- /dev/null +++ b/tests/integration_tests/llms/test_deepinfra.py @@ -0,0 +1,10 @@ +"""Test DeepInfra API wrapper.""" + +from langchain.llms.deepinfra import DeepInfra + + +def test_deepinfra_call() -> None: + """Test valid call to DeepInfra.""" + llm = DeepInfra(model_id="google/flan-t5-small") + output = llm("What is 2 + 2?") + assert isinstance(output, str)