You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
langchain/libs/community/langchain_community/embeddings/infinity_local.py

159 lines
5.1 KiB
Python

"""written under MIT Licence, Michael Feil 2023."""
import asyncio
from logging import getLogger
from typing import Any, Dict, List, Optional
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
# Public names exported by this module (controls ``from ... import *``).
__all__ = ["InfinityEmbeddingsLocal"]
# Module-level logger named after this module; used by InfinityEmbeddingsLocal
# for debug output and warnings.
logger = getLogger(__name__)
class InfinityEmbeddingsLocal(BaseModel, Embeddings):
    """Optimized Infinity embedding models.

    https://github.com/michaelfeil/infinity

    This class deploys a local Infinity instance to embed text.
    The class requires async usage: the engine runs a background worker that
    is started/stopped via ``async with``.

    Example:
        .. code-block:: python

            from langchain_community.embeddings import InfinityEmbeddingsLocal

            async with InfinityEmbeddingsLocal(
                model="BAAI/bge-small-en-v1.5",
                revision=None,
                device="cpu",
            ) as embedder:
                embeddings = await embedder.aembed_documents(["text1", "text2"])
    """

    model: str
    "Underlying model id from huggingface, e.g. BAAI/bge-small-en-v1.5"

    revision: Optional[str] = None
    "Model version, the commit hash from huggingface"

    batch_size: int = 32
    "Internal batch size for inference, e.g. 32"

    device: str = "auto"
    "Device to use for inference, e.g. 'cpu' or 'cuda', or 'mps'"

    backend: str = "torch"
    (
        "Backend for inference, e.g. 'torch' (recommended for ROCm/Nvidia)"
        " or 'optimum' for onnx/tensorrt"
    )

    model_warmup: bool = True
    "Warmup the model with the max batch size."

    engine: Any = None  #: :meta private:
    """Infinity's AsyncEmbeddingEngine."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    # skip_on_failure: if any field failed validation, its key is absent from
    # ``values``; skipping avoids a raw KeyError masking the ValidationError.
    @root_validator(allow_reuse=True, skip_on_failure=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Check that ``infinity_emb`` is installed and build the engine.

        Creates an ``AsyncEmbeddingEngine`` from the validated field values and
        stores it in ``values["engine"]``. Construction alone does not enter
        the engine's async context; use ``async with`` for that.

        Raises:
            ImportError: If the ``infinity_emb`` package is not installed.
        """
        try:
            from infinity_emb import AsyncEmbeddingEngine  # type: ignore
        except ImportError as err:
            raise ImportError(
                "Please install the "
                "`pip install 'infinity_emb[optimum,torch]>=0.0.24'` "
                "package to use the InfinityEmbeddingsLocal."
            ) from err

        logger.debug("Using InfinityEmbeddingsLocal with kwargs %s", values)
        values["engine"] = AsyncEmbeddingEngine(
            model_name_or_path=values["model"],
            device=values["device"],
            revision=values["revision"],
            model_warmup=values["model_warmup"],
            batch_size=values["batch_size"],
            engine=values["backend"],
        )
        return values

    async def __aenter__(self) -> "InfinityEmbeddingsLocal":
        """Start the background worker and return this embedder.

        Recommended usage is with the ``async with`` statement::

            async with InfinityEmbeddingsLocal(
                model="BAAI/bge-small-en-v1.5",
                revision=None,
                device="cpu",
            ) as embedder:
                embeddings = await embedder.aembed_documents(["text1", "text2"])
        """
        await self.engine.__aenter__()
        # Returning self is what binds the ``as embedder`` target.
        return self

    async def __aexit__(self, *args: Any) -> None:
        """Stop the background worker.

        Required to free references to the pytorch model. The exception info
        received from ``async with`` is forwarded to the engine unchanged.
        """
        await self.engine.__aexit__(*args)

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Async call out to Infinity's embedding endpoint.

        If the engine is not running, it is started for this single call and
        stopped afterwards (a warning is logged, as this is slow).

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text. An empty input returns an
            empty list without touching the engine.
        """
        if not texts:
            # Nothing to embed; avoid starting/stopping the engine for nothing.
            return []

        if self.engine.running:
            embeddings, _ = await self.engine.embed(texts)
            return embeddings

        logger.warning(
            "Starting Infinity engine on the fly. This is not recommended. "
            "Please start the engine before using it."
        )
        async with self:
            # spawning threadpool for multithreaded encode, tokenization
            embeddings, _ = await self.engine.embed(texts)
        # threadpool was stopped on exit of the ``async with`` block
        logger.warning("Stopped infinity engine after usage.")
        return embeddings

    async def aembed_query(self, text: str) -> List[float]:
        """Async call out to Infinity's embedding endpoint.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        embeddings = await self.aembed_documents([text])
        return embeddings[0]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Synchronous wrapper around :meth:`aembed_documents`.

        This class is async only; this wrapper logs a warning and drives the
        coroutine with ``asyncio.run``, which raises ``RuntimeError`` if called
        while an event loop is already running in the current thread.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        logger.warning(
            "This method is async only. "
            "Please use the async version `await aembed_documents`."
        )
        return asyncio.run(self.aembed_documents(texts))

    def embed_query(self, text: str) -> List[float]:
        """Synchronous wrapper around :meth:`aembed_query`.

        This class is async only; this wrapper logs a warning and drives the
        coroutine with ``asyncio.run``, which raises ``RuntimeError`` if called
        while an event loop is already running in the current thread.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        logger.warning(
            "This method is async only."
            " Please use the async version `await aembed_query`."
        )
        return asyncio.run(self.aembed_query(text))