mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
Harrison/embaas (#6010)
Co-authored-by: Julius Lipp <43986145+juliuslipp@users.noreply.github.com>
This commit is contained in:
parent
232faba796
commit
a7227ee01b
159
docs/modules/models/text_embedding/examples/embaas.ipynb
Normal file
159
docs/modules/models/text_embedding/examples/embaas.ipynb
Normal file
@ -0,0 +1,159 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"source": [
|
||||||
|
"[embaas](https://embaas.io) is a fully managed NLP API service that offers features like embedding generation, document text extraction, document to embeddings and more. You can choose a [variety of pre-trained models](https://embaas.io/docs/models/embeddings).\n",
|
||||||
|
"\n",
|
||||||
|
"In this tutorial, we will show you how to use the embaas Embeddings API to generate embeddings for a given text.\n",
|
||||||
|
"\n",
|
||||||
|
"### Prerequisites\n",
|
||||||
|
"Create your free embaas account at [https://embaas.io/register](https://embaas.io/register) and generate an [API key](https://embaas.io/dashboard/api-keys)."
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Set API key\n",
|
||||||
|
"embaas_api_key = \"YOUR_API_KEY\"\n",
|
||||||
|
"# or set environment variable\n",
|
||||||
|
"os.environ[\"EMBAAS_API_KEY\"] = \"YOUR_API_KEY\""
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from langchain.embeddings import EmbaasEmbeddings"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"embeddings = EmbaasEmbeddings()"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Create embeddings for a single document\n",
|
||||||
|
"doc_text = \"This is a test document.\"\n",
|
||||||
|
"doc_text_embedding = embeddings.embed_query(doc_text)"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"start_time": "2023-06-10T11:17:55.938517Z",
|
||||||
|
"end_time": "2023-06-10T11:17:55.940265Z"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Print created embedding\n",
|
||||||
|
"print(doc_text_embedding)"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 9,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Create embeddings for multiple documents\n",
|
||||||
|
"doc_texts = [\"This is a test document.\", \"This is another test document.\"]\n",
|
||||||
|
"doc_texts_embeddings = embeddings.embed_documents(doc_texts)"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"start_time": "2023-06-10T11:19:25.235320Z",
|
||||||
|
"end_time": "2023-06-10T11:19:25.237161Z"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Print created embeddings\n",
|
||||||
|
"for i, doc_text_embedding in enumerate(doc_texts_embeddings):\n",
|
||||||
|
" print(f\"Embedding for document {i + 1}: {doc_text_embedding}\")"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 11,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Using a different model and/or custom instruction\n",
|
||||||
|
"embeddings = EmbaasEmbeddings(model=\"instructor-large\", instruction=\"Represent the Wikipedia document for retrieval\")"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"start_time": "2023-06-10T11:22:26.138357Z",
|
||||||
|
"end_time": "2023-06-10T11:22:26.139769Z"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"source": [
|
||||||
|
"For more detailed information about the embaas Embeddings API, please refer to [the official embaas API documentation](https://embaas.io/api-reference)."
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 2
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython2",
|
||||||
|
"version": "2.7.6"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 0
|
||||||
|
}
|
@ -10,6 +10,7 @@ from langchain.embeddings.bedrock import BedrockEmbeddings
|
|||||||
from langchain.embeddings.cohere import CohereEmbeddings
|
from langchain.embeddings.cohere import CohereEmbeddings
|
||||||
from langchain.embeddings.deepinfra import DeepInfraEmbeddings
|
from langchain.embeddings.deepinfra import DeepInfraEmbeddings
|
||||||
from langchain.embeddings.elasticsearch import ElasticsearchEmbeddings
|
from langchain.embeddings.elasticsearch import ElasticsearchEmbeddings
|
||||||
|
from langchain.embeddings.embaas import EmbaasEmbeddings
|
||||||
from langchain.embeddings.fake import FakeEmbeddings
|
from langchain.embeddings.fake import FakeEmbeddings
|
||||||
from langchain.embeddings.google_palm import GooglePalmEmbeddings
|
from langchain.embeddings.google_palm import GooglePalmEmbeddings
|
||||||
from langchain.embeddings.huggingface import (
|
from langchain.embeddings.huggingface import (
|
||||||
@ -60,6 +61,7 @@ __all__ = [
|
|||||||
"VertexAIEmbeddings",
|
"VertexAIEmbeddings",
|
||||||
"BedrockEmbeddings",
|
"BedrockEmbeddings",
|
||||||
"DeepInfraEmbeddings",
|
"DeepInfraEmbeddings",
|
||||||
|
"EmbaasEmbeddings",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
140
langchain/embeddings/embaas.py
Normal file
140
langchain/embeddings/embaas.py
Normal file
@ -0,0 +1,140 @@
|
|||||||
|
"""Wrapper around embaas embeddings API."""
|
||||||
|
from typing import Any, Dict, List, Mapping, Optional
|
||||||
|
|
||||||
|
import requests
|
||||||
|
from pydantic import BaseModel, Extra, root_validator
|
||||||
|
from typing_extensions import NotRequired, TypedDict
|
||||||
|
|
||||||
|
from langchain.embeddings.base import Embeddings
|
||||||
|
from langchain.utils import get_from_dict_or_env
|
||||||
|
|
||||||
|
# Currently supported maximum batch size for embedding requests
|
||||||
|
MAX_BATCH_SIZE = 256
|
||||||
|
EMBAAS_API_URL = "https://api.embaas.io/v1/embeddings/"
|
||||||
|
|
||||||
|
|
||||||
|
class EmbaasEmbeddingsPayload(TypedDict):
|
||||||
|
"""Payload for the embaas embeddings API."""
|
||||||
|
|
||||||
|
model: str
|
||||||
|
texts: List[str]
|
||||||
|
instruction: NotRequired[str]
|
||||||
|
|
||||||
|
|
||||||
|
class EmbaasEmbeddings(BaseModel, Embeddings):
|
||||||
|
"""Wrapper around embaas's embedding service.
|
||||||
|
|
||||||
|
To use, you should have the
|
||||||
|
environment variable ``EMBAAS_API_KEY`` set with your API key, or pass
|
||||||
|
it as a named parameter to the constructor.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
# Initialise with default model and instruction
|
||||||
|
from langchain.llms import EmbaasEmbeddings
|
||||||
|
emb = EmbaasEmbeddings()
|
||||||
|
|
||||||
|
# Initialise with custom model and instruction
|
||||||
|
from langchain.llms import EmbaasEmbeddings
|
||||||
|
emb_model = "instructor-large"
|
||||||
|
emb_inst = "Represent the Wikipedia document for retrieval"
|
||||||
|
emb = EmbaasEmbeddings(
|
||||||
|
model=emb_model,
|
||||||
|
instruction=emb_inst,
|
||||||
|
embaas_api_key="your-api-key"
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
|
||||||
|
model: str = "e5-large-v2"
|
||||||
|
"""The model used for embeddings."""
|
||||||
|
instruction: Optional[str] = None
|
||||||
|
"""Instruction used for domain-specific embeddings."""
|
||||||
|
api_url: str = EMBAAS_API_URL
|
||||||
|
"""The URL for the embaas embeddings API."""
|
||||||
|
embaas_api_key: Optional[str] = None
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
"""Configuration for this pydantic object."""
|
||||||
|
|
||||||
|
extra = Extra.forbid
|
||||||
|
|
||||||
|
@root_validator()
|
||||||
|
def validate_environment(cls, values: Dict) -> Dict:
|
||||||
|
"""Validate that api key and python package exists in environment."""
|
||||||
|
embaas_api_key = get_from_dict_or_env(
|
||||||
|
values, "embaas_api_key", "EMBAAS_API_KEY"
|
||||||
|
)
|
||||||
|
values["embaas_api_key"] = embaas_api_key
|
||||||
|
return values
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _identifying_params(self) -> Mapping[str, Any]:
|
||||||
|
"""Get the identifying params."""
|
||||||
|
return {"model": self.model, "instruction": self.instruction}
|
||||||
|
|
||||||
|
def _generate_payload(self, texts: List[str]) -> EmbaasEmbeddingsPayload:
|
||||||
|
"""Generates payload for the API request."""
|
||||||
|
payload = EmbaasEmbeddingsPayload(texts=texts, model=self.model)
|
||||||
|
if self.instruction:
|
||||||
|
payload["instruction"] = self.instruction
|
||||||
|
return payload
|
||||||
|
|
||||||
|
def _handle_request(self, payload: EmbaasEmbeddingsPayload) -> List[List[float]]:
|
||||||
|
"""Sends a request to the Embaas API and handles the response."""
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {self.embaas_api_key}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
|
||||||
|
response = requests.post(self.api_url, headers=headers, json=payload)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
parsed_response = response.json()
|
||||||
|
embeddings = [item["embedding"] for item in parsed_response["data"]]
|
||||||
|
|
||||||
|
return embeddings
|
||||||
|
|
||||||
|
def _generate_embeddings(self, texts: List[str]) -> List[List[float]]:
|
||||||
|
"""Generate embeddings using the Embaas API."""
|
||||||
|
payload = self._generate_payload(texts)
|
||||||
|
try:
|
||||||
|
return self._handle_request(payload)
|
||||||
|
except requests.exceptions.RequestException as e:
|
||||||
|
if e.response is None or not e.response.text:
|
||||||
|
raise ValueError(f"Error raised by embaas embeddings API: {e}")
|
||||||
|
|
||||||
|
parsed_response = e.response.json()
|
||||||
|
if "message" in parsed_response:
|
||||||
|
raise ValueError(
|
||||||
|
"Validation Error raised by embaas embeddings API:"
|
||||||
|
f"{parsed_response['message']}"
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||||
|
"""Get embeddings for a list of texts.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
texts: The list of texts to get embeddings for.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of embeddings, one for each text.
|
||||||
|
"""
|
||||||
|
batches = [
|
||||||
|
texts[i : i + MAX_BATCH_SIZE] for i in range(0, len(texts), MAX_BATCH_SIZE)
|
||||||
|
]
|
||||||
|
embeddings = [self._generate_embeddings(batch) for batch in batches]
|
||||||
|
# flatten the list of lists into a single list
|
||||||
|
return [embedding for batch in embeddings for embedding in batch]
|
||||||
|
|
||||||
|
def embed_query(self, text: str) -> List[float]:
|
||||||
|
"""Get embeddings for a single text.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: The text to get embeddings for.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of embeddings.
|
||||||
|
"""
|
||||||
|
return self.embed_documents([text])[0]
|
58
tests/integration_tests/embeddings/test_embaas.py
Normal file
58
tests/integration_tests/embeddings/test_embaas.py
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
"""Test embaas embeddings."""
|
||||||
|
import responses
|
||||||
|
|
||||||
|
from langchain.embeddings.embaas import EMBAAS_API_URL, EmbaasEmbeddings
|
||||||
|
|
||||||
|
|
||||||
|
def test_embaas_embed_documents() -> None:
|
||||||
|
"""Test embaas embeddings with multiple texts."""
|
||||||
|
texts = ["foo bar", "bar foo", "foo"]
|
||||||
|
embedding = EmbaasEmbeddings()
|
||||||
|
output = embedding.embed_documents(texts)
|
||||||
|
assert len(output) == 3
|
||||||
|
assert len(output[0]) == 1024
|
||||||
|
assert len(output[1]) == 1024
|
||||||
|
assert len(output[2]) == 1024
|
||||||
|
|
||||||
|
|
||||||
|
def test_embaas_embed_query() -> None:
|
||||||
|
"""Test embaas embeddings with multiple texts."""
|
||||||
|
text = "foo"
|
||||||
|
embeddings = EmbaasEmbeddings()
|
||||||
|
output = embeddings.embed_query(text)
|
||||||
|
assert len(output) == 1024
|
||||||
|
|
||||||
|
|
||||||
|
def test_embaas_embed_query_instruction() -> None:
|
||||||
|
"""Test embaas embeddings with a different instruction."""
|
||||||
|
text = "Test"
|
||||||
|
instruction = "query"
|
||||||
|
embeddings = EmbaasEmbeddings(instruction=instruction)
|
||||||
|
output = embeddings.embed_query(text)
|
||||||
|
assert len(output) == 1024
|
||||||
|
|
||||||
|
|
||||||
|
def test_embaas_embed_query_model() -> None:
|
||||||
|
"""Test embaas embeddings with a different model."""
|
||||||
|
text = "Test"
|
||||||
|
model = "instructor-large"
|
||||||
|
instruction = "Represent the query for retrieval"
|
||||||
|
embeddings = EmbaasEmbeddings(model=model, instruction=instruction)
|
||||||
|
output = embeddings.embed_query(text)
|
||||||
|
assert len(output) == 768
|
||||||
|
|
||||||
|
|
||||||
|
@responses.activate
|
||||||
|
def test_embaas_embed_documents_response() -> None:
|
||||||
|
"""Test embaas embeddings with multiple texts."""
|
||||||
|
responses.add(
|
||||||
|
responses.POST,
|
||||||
|
EMBAAS_API_URL,
|
||||||
|
json={"data": [{"embedding": [0.0] * 1024}]},
|
||||||
|
status=200,
|
||||||
|
)
|
||||||
|
|
||||||
|
text = "asd"
|
||||||
|
embeddings = EmbaasEmbeddings()
|
||||||
|
output = embeddings.embed_query(text)
|
||||||
|
assert len(output) == 1024
|
Loading…
Reference in New Issue
Block a user