"""written under MIT Licence, Michael Feil 2023.""" import asyncio from logging import getLogger from typing import Any, Dict, List, Optional from langchain_core.embeddings import Embeddings from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator __all__ = ["InfinityEmbeddingsLocal"] logger = getLogger(__name__) class InfinityEmbeddingsLocal(BaseModel, Embeddings): """Optimized Embedding models https://github.com/michaelfeil/infinity This class deploys a local Infinity instance to embed text. The class requires async usage. Infinity is a class to interact with Embedding Models on https://github.com/michaelfeil/infinity Example: .. code-block:: python from langchain_community.embeddings import InfinityEmbeddingsLocal async with InfinityEmbeddingsLocal( model="BAAI/bge-small-en-v1.5", revision=None, device="cpu", ) as embedder: embeddings = await engine.aembed_documents(["text1", "text2"]) """ model: str "Underlying model id from huggingface, e.g. BAAI/bge-small-en-v1.5" revision: Optional[str] = None "Model version, the commit hash from huggingface" batch_size: int = 32 "Internal batch size for inference, e.g. 32" device: str = "auto" "Device to use for inference, e.g. 'cpu' or 'cuda', or 'mps'" backend: str = "torch" "Backend for inference, e.g. 'torch' (recommended for ROCm/Nvidia)" " or 'optimum' for onnx/tensorrt" model_warmup: bool = True "Warmup the model with the max batch size." engine: Any = None #: :meta private: """Infinity's AsyncEmbeddingEngine.""" # LLM call kwargs class Config: """Configuration for this pydantic object.""" extra = Extra.forbid @root_validator(allow_reuse=True) def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" try: from infinity_emb import AsyncEmbeddingEngine # type: ignore except ImportError: raise ImportError( "Please install the " "`pip install 'infinity_emb[optimum,torch]>=0.0.24'` " "package to use the InfinityEmbeddingsLocal." ) logger.debug(f"Using InfinityEmbeddingsLocal with kwargs {values}") values["engine"] = AsyncEmbeddingEngine( model_name_or_path=values["model"], device=values["device"], revision=values["revision"], model_warmup=values["model_warmup"], batch_size=values["batch_size"], engine=values["backend"], ) return values async def __aenter__(self) -> None: """start the background worker. recommended usage is with the async with statement. async with InfinityEmbeddingsLocal( model="BAAI/bge-small-en-v1.5", revision=None, device="cpu", ) as embedder: embeddings = await engine.aembed_documents(["text1", "text2"]) """ await self.engine.__aenter__() async def __aexit__(self, *args: Any) -> None: """stop the background worker, required to free references to the pytorch model.""" await self.engine.__aexit__(*args) async def aembed_documents(self, texts: List[str]) -> List[List[float]]: """Async call out to Infinity's embedding endpoint. Args: texts: The list of texts to embed. Returns: List of embeddings, one for each text. """ if not self.engine.running: logger.warning( "Starting Infinity engine on the fly. This is not recommended." "Please start the engine before using it." ) async with self: # spawning threadpool for multithreaded encode, tokenization embeddings, _ = await self.engine.embed(texts) # stopping threadpool on exit logger.warning("Stopped infinity engine after usage.") else: embeddings, _ = await self.engine.embed(texts) return embeddings async def aembed_query(self, text: str) -> List[float]: """Async call out to Infinity's embedding endpoint. Args: text: The text to embed. Returns: Embeddings for the text. """ embeddings = await self.aembed_documents([text]) return embeddings[0] def embed_documents(self, texts: List[str]) -> List[List[float]]: """ This method is async only. """ logger.warning( "This method is async only. " "Please use the async version `await aembed_documents`." ) return asyncio.run(self.aembed_documents(texts)) def embed_query(self, text: str) -> List[float]: """ """ logger.warning( "This method is async only." " Please use the async version `await aembed_query`." ) return asyncio.run(self.aembed_query(text))