mirror of
https://github.com/hwchase17/langchain
synced 2024-11-08 07:10:35 +00:00
7cf2d2759d
Added missing docstrings. Formatted docstrings consistently.
159 lines
5.1 KiB
Python
159 lines
5.1 KiB
Python
"""written under MIT Licence, Michael Feil 2023."""
|
|
|
|
import asyncio
|
|
from logging import getLogger
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
from langchain_core.embeddings import Embeddings
|
|
from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
|
|
|
|
# Public API of this module: only the local Infinity embeddings class.
__all__ = ["InfinityEmbeddingsLocal"]

# Module-level logger named after this module; used for debug/warning output below.
logger = getLogger(__name__)
|
|
|
|
|
|
class InfinityEmbeddingsLocal(BaseModel, Embeddings):
    """Optimized Infinity embedding models.

    https://github.com/michaelfeil/infinity

    This class deploys a local Infinity instance to embed text.
    The class requires async usage.

    Infinity is a class to interact with Embedding Models on
    https://github.com/michaelfeil/infinity

    Example:
        .. code-block:: python

            from langchain_community.embeddings import InfinityEmbeddingsLocal

            async with InfinityEmbeddingsLocal(
                model="BAAI/bge-small-en-v1.5",
                revision=None,
                device="cpu",
            ) as embedder:
                embeddings = await embedder.aembed_documents(["text1", "text2"])
    """

    model: str
    "Underlying model id from huggingface, e.g. BAAI/bge-small-en-v1.5"

    revision: Optional[str] = None
    "Model version, the commit hash from huggingface"

    batch_size: int = 32
    "Internal batch size for inference, e.g. 32"

    device: str = "auto"
    "Device to use for inference, e.g. 'cpu' or 'cuda', or 'mps'"

    backend: str = "torch"
    """Backend for inference, e.g. 'torch' (recommended for ROCm/Nvidia)
    or 'optimum' for onnx/tensorrt"""

    model_warmup: bool = True
    "Warmup the model with the max batch size."

    engine: Any = None  #: :meta private:
    """Infinity's AsyncEmbeddingEngine."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    # skip_on_failure: if a field already failed validation (e.g. `model`
    # missing), skip this validator so pydantic reports the ValidationError
    # instead of this function raising a bare KeyError on `values[...]`.
    @root_validator(allow_reuse=True, skip_on_failure=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Check that ``infinity_emb`` is installed and build the engine.

        Args:
            values: The validated field values of the model.

        Returns:
            ``values`` with ``"engine"`` set to a new ``AsyncEmbeddingEngine``
            configured from the ``model``, ``device``, ``revision``,
            ``model_warmup``, ``batch_size`` and ``backend`` fields.

        Raises:
            ImportError: If the ``infinity_emb`` package cannot be imported.
        """
        try:
            from infinity_emb import AsyncEmbeddingEngine  # type: ignore
        except ImportError as e:
            raise ImportError(
                "Please install the "
                "`pip install 'infinity_emb[optimum,torch]>=0.0.24'` "
                "package to use the InfinityEmbeddingsLocal."
            ) from e

        logger.debug("Using InfinityEmbeddingsLocal with kwargs %s", values)

        # The engine is only constructed here; its background worker is
        # started later via `__aenter__`.
        values["engine"] = AsyncEmbeddingEngine(
            model_name_or_path=values["model"],
            device=values["device"],
            revision=values["revision"],
            model_warmup=values["model_warmup"],
            batch_size=values["batch_size"],
            engine=values["backend"],
        )
        return values

    async def __aenter__(self) -> "InfinityEmbeddingsLocal":
        """Start the background worker of the Infinity engine.

        Recommended usage is with the ``async with`` statement, which also
        stops the worker on exit::

            async with InfinityEmbeddingsLocal(
                model="BAAI/bge-small-en-v1.5",
                revision=None,
                device="cpu",
            ) as embedder:
                embeddings = await embedder.aembed_documents(["text1", "text2"])

        Returns:
            This instance, so that ``async with ... as embedder`` binds a
            usable embedder (not ``None``).
        """
        await self.engine.__aenter__()
        return self

    async def __aexit__(self, *args: Any) -> None:
        """Stop the background worker.

        Required to free references to the pytorch model.

        Args:
            *args: The exception type, value and traceback from the
                ``async with`` body, forwarded unchanged to the engine.
        """
        await self.engine.__aexit__(*args)

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Async call out to Infinity's embedding endpoint.

        If the engine is not running yet, it is started for this single call
        and stopped again afterwards (with a warning), which is slow; prefer
        entering the ``async with`` block once up front.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text. An empty list for empty
            input, without touching the engine.
        """
        if not texts:
            # Nothing to embed: avoid spinning up the engine on the fly.
            return []

        if self.engine.running:
            # `embed` returns a pair; only the embeddings are used here.
            embeddings, _ = await self.engine.embed(texts)
            return embeddings

        logger.warning(
            "Starting Infinity engine on the fly. This is not recommended. "
            "Please start the engine before using it."
        )
        async with self:
            # spawning threadpool for multithreaded encode, tokenization
            embeddings, _ = await self.engine.embed(texts)
        # the threadpool was stopped when the `async with` block exited
        logger.warning("Stopped infinity engine after usage.")
        return embeddings

    async def aembed_query(self, text: str) -> List[float]:
        """Async call out to Infinity's embedding endpoint.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        embeddings = await self.aembed_documents([text])
        return embeddings[0]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Synchronous wrapper around :meth:`aembed_documents`.

        This class is async only. This method logs a warning and runs
        ``aembed_documents`` via ``asyncio.run``, so it cannot be called from
        inside an already-running event loop (``asyncio.run`` raises
        ``RuntimeError`` in that case).

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        logger.warning(
            "This method is async only. "
            "Please use the async version `await aembed_documents`."
        )
        return asyncio.run(self.aembed_documents(texts))

    def embed_query(self, text: str) -> List[float]:
        """Synchronous wrapper around :meth:`aembed_query`.

        This class is async only. This method logs a warning and runs
        ``aembed_query`` via ``asyncio.run``, so it cannot be called from
        inside an already-running event loop.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        logger.warning(
            "This method is async only."
            " Please use the async version `await aembed_query`."
        )
        return asyncio.run(self.aembed_query(text))
|