mirror of
https://github.com/hwchase17/langchain
synced 2024-11-02 09:40:22 +00:00
VoyageEmbeddings (#12608)
- **Description:** Integrate VoyageEmbeddings into LangChain, with tests and docs - **Issue:** N/A - **Dependencies:** N/A - **Tag maintainer:** N/A - **Twitter handle:** @Voyage_AI_ --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
92bf40a921
commit
1dbb77d7db
228
docs/docs/integrations/text_embedding/voyageai.ipynb
Normal file
228
docs/docs/integrations/text_embedding/voyageai.ipynb
Normal file
@ -0,0 +1,228 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "278b6c63",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Voyage AI\n",
|
||||
"\n",
|
||||
"Let's load the Voyage Embedding class."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "0be1af71",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.embeddings import VoyageEmbeddings"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "137cfde9-b88c-409a-9394-a9e31a6bf30d",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Voyage AI utilizes API keys to monitor usage and manage permissions. To obtain your key, create an account on our [homepage](https://www.voyageai.com). Then, create a VoyageEmbeddings model with your API key."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "2c66e5da",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"embeddings = VoyageEmbeddings(voyage_api_key=\"[ Your Voyage API key ]\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "459dffb3-9bff-41f2-8507-642de7431b2d",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Prepare the documents and use `embed_documents` to get their embeddings."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "c85e948f-85fd-4d56-8d21-6e2f7e65cab8",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"documents = [\n",
|
||||
" \"Caching embeddings enables the storage or temporary caching of embeddings, eliminating the necessity to recompute them each time.\",\n",
|
||||
" \"An LLMChain is a chain that composes basic LLM functionality. It consists of a PromptTemplate and a language model (either an LLM or chat model). It formats the prompt template using the input key values provided (and also memory key values, if available), passes the formatted string to LLM and returns the LLM output.\",\n",
|
||||
" \"A Runnable represents a generic unit of work that can be invoked, batched, streamed, and/or transformed.\",\n",
|
||||
"]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "5a77a12d-6ac6-4ab8-b103-80ff24487019",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"documents_embds = embeddings.embed_documents(documents)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "2c89167c-816c-487e-8704-90908a4190bb",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[0.0562174916267395,\n",
|
||||
" 0.018221192061901093,\n",
|
||||
" 0.0025736060924828053,\n",
|
||||
" -0.009720131754875183,\n",
|
||||
" 0.04108370840549469]"
|
||||
]
|
||||
},
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"documents_embds[0][:5]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "f8d796d1-4ced-44d3-81bf-282721edb6bb",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Similarly, use `embed_query` to embed the query."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "bfb6142c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"query = \"What's an LLMChain?\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "91bc875d-829b-4c3d-8e6f-fc2dda30a3bd",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"query_embd = embeddings.embed_query(query)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "a4b0d49e-0c73-44b6-aed5-5b426564e085",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[-0.0052348352037370205,\n",
|
||||
" -0.040072452276945114,\n",
|
||||
" 0.0033957737032324076,\n",
|
||||
" 0.01763271726667881,\n",
|
||||
" -0.019235141575336456]"
|
||||
]
|
||||
},
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"query_embd[:5]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "b16ddbb2-61f0-49ec-92c3-a6f236d9517f",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## A minimalist retrieval system"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "5464cb0a-6967-4f1e-ac7c-0aab80b2795a",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The main feature of the embeddings is that the cosine similarity between two embeddings captures the semantic relatedness of the corresponding original passages. This allows us to use the embeddings to do semantic retrieval / search."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "a0bd3ad2-ca68-4e75-9172-76aea28ba46e",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
" We can find a few closest embeddings in the documents embeddings based on the cosine similarity, and retrieve the corresponding document using the `KNNRetriever` class from LangChain."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"id": "0a3fc579-85a9-4bd0-a944-4e32ac62e2d4",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"An LLMChain is a chain that composes basic LLM functionality. It consists of a PromptTemplate and a language model (either an LLM or chat model). It formats the prompt template using the input key values provided (and also memory key values, if available), passes the formatted string to LLM and returns the LLM output.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from langchain.retrievers import KNNRetriever\n",
|
||||
"\n",
|
||||
"retriever = KNNRetriever.from_texts(documents, embeddings)\n",
|
||||
"\n",
|
||||
"# retrieve the most relevant documents\n",
|
||||
"result = retriever.get_relevant_documents(query)\n",
|
||||
"top1_retrieved_doc = result[0].page_content # return the top1 retrieved result\n",
|
||||
"\n",
|
||||
"print(top1_retrieved_doc)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.18"
|
||||
},
|
||||
"vscode": {
|
||||
"interpreter": {
|
||||
"hash": "e971737741ff4ec9aff7dc6155a1060a59a8a6d52c757dbbe66bf8ee389494b1"
|
||||
}
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
@ -64,6 +64,7 @@ from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddi
|
||||
from langchain.embeddings.spacy_embeddings import SpacyEmbeddings
|
||||
from langchain.embeddings.tensorflow_hub import TensorflowHubEmbeddings
|
||||
from langchain.embeddings.vertexai import VertexAIEmbeddings
|
||||
from langchain.embeddings.voyageai import VoyageEmbeddings
|
||||
from langchain.embeddings.xinference import XinferenceEmbeddings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -115,6 +116,7 @@ __all__ = [
|
||||
"OllamaEmbeddings",
|
||||
"QianfanEmbeddingsEndpoint",
|
||||
"JohnSnowLabsEmbeddings",
|
||||
"VoyageEmbeddings",
|
||||
]
|
||||
|
||||
|
||||
|
158
libs/langchain/langchain/embeddings/voyageai.py
Normal file
158
libs/langchain/langchain/embeddings/voyageai.py
Normal file
@ -0,0 +1,158 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
import requests
|
||||
from tenacity import (
|
||||
before_sleep_log,
|
||||
retry,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
)
|
||||
|
||||
from langchain.pydantic_v1 import BaseModel, Extra, SecretStr, root_validator
|
||||
from langchain.schema.embeddings import Embeddings
|
||||
from langchain.utils import convert_to_secret_str, get_from_dict_or_env
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _create_retry_decorator(embeddings: VoyageEmbeddings) -> Callable[[Any], Any]:
|
||||
min_seconds = 4
|
||||
max_seconds = 10
|
||||
# Wait 2^x * 1 second between each retry starting with
|
||||
# 4 seconds, then up to 10 seconds, then 10 seconds afterwards
|
||||
return retry(
|
||||
reraise=True,
|
||||
stop=stop_after_attempt(embeddings.max_retries),
|
||||
wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
|
||||
before_sleep=before_sleep_log(logger, logging.WARNING),
|
||||
)
|
||||
|
||||
|
||||
def _check_response(response: dict) -> dict:
|
||||
if "data" not in response:
|
||||
raise RuntimeError(f"Voyage API Error. Message: {json.dumps(response)}")
|
||||
return response
|
||||
|
||||
|
||||
def embed_with_retry(embeddings: VoyageEmbeddings, **kwargs: Any) -> Any:
|
||||
"""Use tenacity to retry the embedding call."""
|
||||
retry_decorator = _create_retry_decorator(embeddings)
|
||||
|
||||
@retry_decorator
|
||||
def _embed_with_retry(**kwargs: Any) -> Any:
|
||||
response = requests.post(**kwargs)
|
||||
return _check_response(response.json())
|
||||
|
||||
return _embed_with_retry(**kwargs)
|
||||
|
||||
|
||||
class VoyageEmbeddings(BaseModel, Embeddings):
|
||||
"""Voyage embedding models.
|
||||
|
||||
To use, you should have the environment variable ``VOYAGE_API_KEY`` set with
|
||||
your API key or pass it as a named parameter to the constructor.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.embeddings import VoyageEmbeddings
|
||||
|
||||
voyage = VoyageEmbeddings(voyage_api_key="your-api-key")
|
||||
text = "This is a test query."
|
||||
query_result = voyage.embed_query(text)
|
||||
"""
|
||||
|
||||
model: str = "voyage-01"
|
||||
voyage_api_base: str = "https://api.voyageai.com/v1/embeddings"
|
||||
voyage_api_key: Optional[SecretStr] = None
|
||||
batch_size: int = 8
|
||||
"""Maximum number of texts to embed in each API request."""
|
||||
max_retries: int = 6
|
||||
"""Maximum number of retries to make when generating."""
|
||||
request_timeout: Optional[Union[float, Tuple[float, float]]] = None
|
||||
"""Timeout in seconds for the API request."""
|
||||
show_progress_bar: bool = False
|
||||
"""Whether to show a progress bar when embedding. Must have tqdm installed if set
|
||||
to True."""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
|
||||
@root_validator(pre=True)
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
values["voyage_api_key"] = convert_to_secret_str(
|
||||
get_from_dict_or_env(values, "voyage_api_key", "VOYAGE_API_KEY")
|
||||
)
|
||||
return values
|
||||
|
||||
def _invocation_params(self, input: List[str]) -> Dict:
|
||||
api_key = cast(SecretStr, self.voyage_api_key).get_secret_value()
|
||||
params = {
|
||||
"url": self.voyage_api_base,
|
||||
"headers": {"Authorization": f"Bearer {api_key}"},
|
||||
"json": {"model": self.model, "input": input},
|
||||
"timeout": self.request_timeout,
|
||||
}
|
||||
return params
|
||||
|
||||
def _get_embeddings(self, texts: List[str], batch_size: int) -> List[List[float]]:
|
||||
embeddings: List[List[float]] = []
|
||||
|
||||
if self.show_progress_bar:
|
||||
try:
|
||||
from tqdm.auto import tqdm
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Must have tqdm installed if `show_progress_bar` is set to True. "
|
||||
"Please install with `pip install tqdm`."
|
||||
) from e
|
||||
|
||||
_iter = tqdm(range(0, len(texts), batch_size))
|
||||
else:
|
||||
_iter = range(0, len(texts), batch_size)
|
||||
|
||||
for i in _iter:
|
||||
response = embed_with_retry(
|
||||
self, **self._invocation_params(input=texts[i : i + batch_size])
|
||||
)
|
||||
embeddings.extend(r["embedding"] for r in response["data"])
|
||||
|
||||
return embeddings
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Call out to Voyage Embedding endpoint for embedding search docs.
|
||||
|
||||
Args:
|
||||
texts: The list of texts to embed.
|
||||
|
||||
Returns:
|
||||
List of embeddings, one for each text.
|
||||
"""
|
||||
return self._get_embeddings(texts, batch_size=self.batch_size)
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
"""Call out to Voyage Embedding endpoint for embedding query text.
|
||||
|
||||
Args:
|
||||
text: The text to embed.
|
||||
|
||||
Returns:
|
||||
Embedding for the text.
|
||||
"""
|
||||
return self.embed_documents([text])[0]
|
@ -0,0 +1,33 @@
|
||||
"""Test voyage embeddings."""
|
||||
from langchain.embeddings.voyageai import VoyageEmbeddings
|
||||
|
||||
# Please set VOYAGE_API_KEY in the environment variables
|
||||
MODEL = "voyage-01"
|
||||
|
||||
|
||||
def test_voyagi_embedding_documents() -> None:
|
||||
"""Test voyage embeddings."""
|
||||
documents = ["foo bar"]
|
||||
embedding = VoyageEmbeddings(model=MODEL)
|
||||
output = embedding.embed_documents(documents)
|
||||
assert len(output) == 1
|
||||
assert len(output[0]) == 1024
|
||||
|
||||
|
||||
def test_voyage_embedding_documents_multiple() -> None:
|
||||
"""Test voyage embeddings."""
|
||||
documents = ["foo bar", "bar foo", "foo"]
|
||||
embedding = VoyageEmbeddings(model=MODEL, batch_size=2)
|
||||
output = embedding.embed_documents(documents)
|
||||
assert len(output) == 3
|
||||
assert len(output[0]) == 1024
|
||||
assert len(output[1]) == 1024
|
||||
assert len(output[2]) == 1024
|
||||
|
||||
|
||||
def test_voyage_embedding_query() -> None:
|
||||
"""Test voyage embeddings."""
|
||||
document = "foo bar"
|
||||
embedding = VoyageEmbeddings(model=MODEL)
|
||||
output = embedding.embed_query(document)
|
||||
assert len(output) == 1024
|
Loading…
Reference in New Issue
Block a user