# Mirror of https://github.com/hwchase17/langchain
# Synced 2024-11-10 01:10:59 +00:00
# 131 lines, 4.1 KiB, Python
|
import logging
|
||
|
import os
|
||
|
from typing import Iterable, List, Optional
|
||
|
|
||
|
import voyageai # type: ignore
|
||
|
from langchain_core.embeddings import Embeddings
|
||
|
from langchain_core.pydantic_v1 import (
|
||
|
BaseModel,
|
||
|
Extra,
|
||
|
Field,
|
||
|
SecretStr,
|
||
|
root_validator,
|
||
|
)
|
||
|
from langchain_core.utils import convert_to_secret_str
|
||
|
|
||
|
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
|
||
|
|
||
|
|
||
|
class VoyageAIEmbeddings(BaseModel, Embeddings):
    """VoyageAIEmbeddings embedding model.

    Texts are passed to the VoyageAI client's ``embed`` method in slices of
    at most ``batch_size`` items. The ``embed_*`` methods use the synchronous
    client and the ``aembed_*`` methods use the asynchronous client.

    Example:
        .. code-block:: python

            from langchain_voyageai import VoyageAIEmbeddings

            model = VoyageAIEmbeddings(model="voyage-2")
    """

    # Sync and async VoyageAI clients; both are created in
    # ``validate_environment``.
    _client: voyageai.Client = Field(exclude=True)
    _aclient: voyageai.client_async.AsyncClient = Field(exclude=True)
    # Model name forwarded as ``model=`` to every ``embed`` call (required).
    model: str
    # Number of texts sent per ``embed`` call; when omitted, a per-model
    # default is filled in by ``default_values``.
    batch_size: int
    # When True, the batch iterator is wrapped in a tqdm progress bar.
    show_progress_bar: bool = False
    # Forwarded unchanged as ``truncation=`` to every ``embed`` call.
    truncation: Optional[bool] = None
    # API key; when unset, the ``VOYAGE_API_KEY`` environment variable is used.
    voyage_api_key: Optional[SecretStr] = None

    class Config:
        # Reject constructor arguments that are not declared fields.
        extra = Extra.forbid

    @root_validator(pre=True)
    def default_values(cls, values: dict) -> dict:
        """Set the default batch size based on the model.

        Runs before field validation (``pre=True``), so the required
        ``batch_size`` field can be populated when it was not supplied.

        Args:
            values: Raw constructor keyword arguments.

        Returns:
            ``values`` with ``batch_size`` set to 72 for ``voyage-2`` /
            ``voyage-02`` and 7 for any other model, if it was missing or
            ``None``; otherwise unchanged.
        """
        if values.get("batch_size") is None:
            model = values.get("model")
            values["batch_size"] = 72 if model in ("voyage-2", "voyage-02") else 7
        return values

    @root_validator()
    def validate_environment(cls, values: dict) -> dict:
        """Resolve the API key and create the sync and async clients.

        The key is taken from ``voyage_api_key`` if truthy, otherwise from
        the ``VOYAGE_API_KEY`` environment variable. When a key is found it
        is stored back into ``values`` as a secret string. When none is
        found, both clients are constructed with ``api_key=None``.

        Args:
            values: Field values produced by validation.

        Returns:
            ``values`` with ``_client`` and ``_aclient`` added.
        """
        voyage_api_key = values.get("voyage_api_key") or os.getenv("VOYAGE_API_KEY")
        if voyage_api_key:
            api_key_secretstr = convert_to_secret_str(voyage_api_key)
            values["voyage_api_key"] = api_key_secretstr

            # The clients take the plain string, not the SecretStr wrapper.
            api_key_str = api_key_secretstr.get_secret_value()
        else:
            api_key_str = None
        values["_client"] = voyageai.Client(api_key=api_key_str)
        values["_aclient"] = voyageai.client_async.AsyncClient(api_key=api_key_str)
        return values

    def _get_batch_iterator(self, texts: List[str]) -> Iterable:
        """Return an iterable over the start index of each batch in ``texts``.

        Indices run from 0 to ``len(texts)`` in steps of ``batch_size``.
        When ``show_progress_bar`` is True the range is wrapped in ``tqdm``.

        Raises:
            ImportError: If ``show_progress_bar`` is True and ``tqdm`` is not
                installed.
        """
        if self.show_progress_bar:
            try:
                from tqdm.auto import tqdm  # type: ignore
            except ImportError as e:
                raise ImportError(
                    "Must have tqdm installed if `show_progress_bar` is set to True. "
                    "Please install with `pip install tqdm`."
                ) from e

            _iter = tqdm(range(0, len(texts), self.batch_size))
        else:
            _iter = range(0, len(texts), self.batch_size)  # type: ignore

        return _iter

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed search docs.

        Args:
            texts: Documents to embed; sent in slices of ``batch_size`` with
                ``input_type="document"``.

        Returns:
            The embeddings returned for each batch, concatenated in batch
            order. Empty when ``texts`` is empty.
        """
        embeddings: List[List[float]] = []

        _iter = self._get_batch_iterator(texts)
        for i in _iter:
            embeddings.extend(
                self._client.embed(
                    texts[i : i + self.batch_size],
                    model=self.model,
                    input_type="document",
                    truncation=self.truncation,
                ).embeddings
            )

        return embeddings

    def embed_query(self, text: str) -> List[float]:
        """Embed query text.

        Args:
            text: Query to embed; sent as a single-item list with
                ``input_type="query"``.

        Returns:
            The first (only) embedding in the response.
        """
        return self._client.embed(
            [text], model=self.model, input_type="query", truncation=self.truncation
        ).embeddings[0]

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Asynchronously embed search docs.

        Batches are awaited one after another using the async client.

        Args:
            texts: Documents to embed; sent in slices of ``batch_size`` with
                ``input_type="document"``.

        Returns:
            The embeddings returned for each batch, concatenated in batch
            order. Empty when ``texts`` is empty.
        """
        embeddings: List[List[float]] = []

        _iter = self._get_batch_iterator(texts)
        for i in _iter:
            r = await self._aclient.embed(
                texts[i : i + self.batch_size],
                model=self.model,
                input_type="document",
                truncation=self.truncation,
            )
            embeddings.extend(r.embeddings)

        return embeddings

    async def aembed_query(self, text: str) -> List[float]:
        """Asynchronously embed query text.

        Args:
            text: Query to embed; sent as a single-item list with
                ``input_type="query"``.

        Returns:
            The first (only) embedding in the response.
        """
        r = await self._aclient.embed(
            [text],
            model=self.model,
            input_type="query",
            truncation=self.truncation,
        )
        return r.embeddings[0]