Harrison/embaas (#6010)

Co-authored-by: Julius Lipp <43986145+juliuslipp@users.noreply.github.com>
This commit is contained in:
Harrison Chase 2023-06-11 13:35:14 -07:00 committed by GitHub
parent 232faba796
commit a7227ee01b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 359 additions and 0 deletions

View File

@ -0,0 +1,159 @@
{
"cells": [
{
"cell_type": "markdown",
"source": [
"[embaas](https://embaas.io) is a fully managed NLP API service that offers features like embedding generation, document text extraction, document to embeddings and more. You can choose a [variety of pre-trained models](https://embaas.io/docs/models/embeddings).\n",
"\n",
"In this tutorial, we will show you how to use the embaas Embeddings API to generate embeddings for a given text.\n",
"\n",
"### Prerequisites\n",
"Create your free embaas account at [https://embaas.io/register](https://embaas.io/register) and generate an [API key](https://embaas.io/dashboard/api-keys)."
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"import os\n",
"\n",
"# Set API key\n",
"embaas_api_key = \"YOUR_API_KEY\"\n",
"# or set environment variable\n",
"os.environ[\"EMBAAS_API_KEY\"] = \"YOUR_API_KEY\""
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"from langchain.embeddings import EmbaasEmbeddings"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"embeddings = EmbaasEmbeddings()"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"# Create embeddings for a single document\n",
"doc_text = \"This is a test document.\"\n",
"doc_text_embedding = embeddings.embed_query(doc_text)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2023-06-10T11:17:55.938517Z",
"end_time": "2023-06-10T11:17:55.940265Z"
}
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"# Print created embedding\n",
"print(doc_text_embedding)"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 9,
"outputs": [],
"source": [
"# Create embeddings for multiple documents\n",
"doc_texts = [\"This is a test document.\", \"This is another test document.\"]\n",
"doc_texts_embeddings = embeddings.embed_documents(doc_texts)"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2023-06-10T11:19:25.235320Z",
"end_time": "2023-06-10T11:19:25.237161Z"
}
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"# Print created embeddings\n",
"for i, doc_text_embedding in enumerate(doc_texts_embeddings):\n",
" print(f\"Embedding for document {i + 1}: {doc_text_embedding}\")"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 11,
"outputs": [],
"source": [
"# Using a different model and/or custom instruction\n",
"embeddings = EmbaasEmbeddings(model=\"instructor-large\", instruction=\"Represent the Wikipedia document for retrieval\")"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"start_time": "2023-06-10T11:22:26.138357Z",
"end_time": "2023-06-10T11:22:26.139769Z"
}
}
},
{
"cell_type": "markdown",
"source": [
"For more detailed information about the embaas Embeddings API, please refer to [the official embaas API documentation](https://embaas.io/api-reference)."
],
"metadata": {
"collapsed": false
}
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}

View File

@ -10,6 +10,7 @@ from langchain.embeddings.bedrock import BedrockEmbeddings
from langchain.embeddings.cohere import CohereEmbeddings from langchain.embeddings.cohere import CohereEmbeddings
from langchain.embeddings.deepinfra import DeepInfraEmbeddings from langchain.embeddings.deepinfra import DeepInfraEmbeddings
from langchain.embeddings.elasticsearch import ElasticsearchEmbeddings from langchain.embeddings.elasticsearch import ElasticsearchEmbeddings
from langchain.embeddings.embaas import EmbaasEmbeddings
from langchain.embeddings.fake import FakeEmbeddings from langchain.embeddings.fake import FakeEmbeddings
from langchain.embeddings.google_palm import GooglePalmEmbeddings from langchain.embeddings.google_palm import GooglePalmEmbeddings
from langchain.embeddings.huggingface import ( from langchain.embeddings.huggingface import (
@ -60,6 +61,7 @@ __all__ = [
"VertexAIEmbeddings", "VertexAIEmbeddings",
"BedrockEmbeddings", "BedrockEmbeddings",
"DeepInfraEmbeddings", "DeepInfraEmbeddings",
"EmbaasEmbeddings",
] ]

View File

@ -0,0 +1,140 @@
"""Wrapper around embaas embeddings API."""
from typing import Any, Dict, List, Mapping, Optional
import requests
from pydantic import BaseModel, Extra, root_validator
from typing_extensions import NotRequired, TypedDict
from langchain.embeddings.base import Embeddings
from langchain.utils import get_from_dict_or_env
# Currently supported maximum batch size for embedding requests
MAX_BATCH_SIZE = 256
# Default endpoint that embedding requests are POSTed to
EMBAAS_API_URL = "https://api.embaas.io/v1/embeddings/"
class EmbaasEmbeddingsPayload(TypedDict):
    """Payload for the embaas embeddings API.

    Instances are passed as the JSON body of the embeddings request.
    """

    # Name of the embedding model to use.
    model: str
    # Texts to embed in a single request.
    texts: List[str]
    # Optional instruction; the key may be absent from the payload.
    instruction: NotRequired[str]
class EmbaasEmbeddings(BaseModel, Embeddings):
    """Wrapper around embaas's embedding service.

    To use, you should have the environment variable ``EMBAAS_API_KEY`` set
    with your API key, or pass it as a named parameter to the constructor.

    Example:
        .. code-block:: python

            # Initialise with default model and instruction
            from langchain.embeddings import EmbaasEmbeddings
            emb = EmbaasEmbeddings()

            # Initialise with custom model and instruction
            from langchain.embeddings import EmbaasEmbeddings
            emb_model = "instructor-large"
            emb_inst = "Represent the Wikipedia document for retrieval"
            emb = EmbaasEmbeddings(
                model=emb_model,
                instruction=emb_inst,
                embaas_api_key="your-api-key"
            )
    """

    model: str = "e5-large-v2"
    """The model used for embeddings."""
    instruction: Optional[str] = None
    """Instruction used for domain-specific embeddings."""
    api_url: str = EMBAAS_API_URL
    """The URL for the embaas embeddings API."""
    embaas_api_key: Optional[str] = None
    """The embaas API key; looked up in ``EMBAAS_API_KEY`` when not passed."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve the API key from the constructor values or the environment.

        Args:
            values: The field values collected by pydantic.

        Returns:
            The same mapping with ``embaas_api_key`` filled in.
        """
        embaas_api_key = get_from_dict_or_env(
            values, "embaas_api_key", "EMBAAS_API_KEY"
        )
        values["embaas_api_key"] = embaas_api_key
        return values

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying params."""
        return {"model": self.model, "instruction": self.instruction}

    def _generate_payload(self, texts: List[str]) -> EmbaasEmbeddingsPayload:
        """Build the request payload for a batch of texts.

        The ``instruction`` key is only included when an instruction is set.
        """
        payload = EmbaasEmbeddingsPayload(texts=texts, model=self.model)
        if self.instruction:
            payload["instruction"] = self.instruction
        return payload

    def _handle_request(self, payload: EmbaasEmbeddingsPayload) -> List[List[float]]:
        """Send a request to the embaas API and extract the embeddings.

        Raises:
            requests.exceptions.RequestException: If the request fails or the
                API answers with an error status.
        """
        headers = {
            "Authorization": f"Bearer {self.embaas_api_key}",
            "Content-Type": "application/json",
        }

        response = requests.post(self.api_url, headers=headers, json=payload)
        response.raise_for_status()

        parsed_response = response.json()
        embeddings = [item["embedding"] for item in parsed_response["data"]]

        return embeddings

    def _generate_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Generate embeddings for one batch using the embaas API.

        Args:
            texts: The texts of a single batch.

        Returns:
            One embedding per input text.

        Raises:
            ValueError: If the request fails without a usable error body, or
                the API reports a validation error via a ``message`` field.
            requests.exceptions.RequestException: For any other API error.
        """
        payload = self._generate_payload(texts)

        try:
            return self._handle_request(payload)
        except requests.exceptions.RequestException as e:
            # Compare against None explicitly: a Response is falsy for any
            # non-2xx status, which is exactly the case handled here.
            if e.response is None or not e.response.text:
                raise ValueError(
                    f"Error raised by embaas embeddings API: {e}"
                ) from e

            try:
                parsed_response = e.response.json()
            except ValueError:
                # The error body is not JSON (e.g. an HTML gateway page);
                # report the original request error instead of a decode error.
                raise ValueError(
                    f"Error raised by embaas embeddings API: {e}"
                ) from e

            if isinstance(parsed_response, dict) and "message" in parsed_response:
                raise ValueError(
                    "Validation Error raised by embaas embeddings API: "
                    f"{parsed_response['message']}"
                ) from e
            raise

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Get embeddings for a list of texts.

        The texts are sent in batches of at most ``MAX_BATCH_SIZE`` items.

        Args:
            texts: The list of texts to get embeddings for.

        Returns:
            List of embeddings, one for each text.
        """
        batches = [
            texts[i : i + MAX_BATCH_SIZE] for i in range(0, len(texts), MAX_BATCH_SIZE)
        ]
        embeddings = [self._generate_embeddings(batch) for batch in batches]
        # flatten the list of lists into a single list
        return [embedding for batch in embeddings for embedding in batch]

    def embed_query(self, text: str) -> List[float]:
        """Get embeddings for a single text.

        Args:
            text: The text to get embeddings for.

        Returns:
            List of embeddings.
        """
        return self.embed_documents([text])[0]

View File

@ -0,0 +1,58 @@
"""Test embaas embeddings."""
import responses
from langchain.embeddings.embaas import EMBAAS_API_URL, EmbaasEmbeddings
def test_embaas_embed_documents() -> None:
    """Embed several texts and check the number and size of the vectors."""
    documents = ["foo bar", "bar foo", "foo"]
    vectors = EmbaasEmbeddings().embed_documents(documents)
    assert len(vectors) == 3
    for vector in vectors:
        assert len(vector) == 1024
def test_embaas_embed_query() -> None:
    """Embed a single query text and check the size of the vector."""
    vector = EmbaasEmbeddings().embed_query("foo")
    assert len(vector) == 1024
def test_embaas_embed_query_instruction() -> None:
    """Embed a query with a custom instruction and check the vector size."""
    client = EmbaasEmbeddings(instruction="query")
    vector = client.embed_query("Test")
    assert len(vector) == 1024
def test_embaas_embed_query_model() -> None:
    """Embed a query with a non-default model and check the vector size."""
    client = EmbaasEmbeddings(
        model="instructor-large",
        instruction="Represent the query for retrieval",
    )
    vector = client.embed_query("Test")
    assert len(vector) == 768
@responses.activate
def test_embaas_embed_documents_response() -> None:
    """Test that a mocked embaas API response is parsed into an embedding."""
    responses.add(
        responses.POST,
        EMBAAS_API_URL,
        json={"data": [{"embedding": [0.0] * 1024}]},
        status=200,
    )

    text = "asd"
    # The HTTP call is mocked, so pass a dummy key explicitly instead of
    # relying on EMBAAS_API_KEY being set in the environment.
    embeddings = EmbaasEmbeddings(embaas_api_key="test-api-key")
    output = embeddings.embed_query(text)
    assert len(output) == 1024