mirror of https://github.com/hwchase17/langchain
community[minor]: infinity embedding local option (#17671)
**Drop-in replacement for sentence-transformers inference.** See https://github.com/langchain-ai/langchain/discussions/17670. TL;DR from the discussion above: around a 4x-22x speedup over using SentenceTransformers / Hugging Face embeddings. For more info, see https://github.com/michaelfeil/infinity (pure-Python dependency). --------- Co-authored-by: Erick Friis <erick@langchain.dev> pull/17879/head
parent
581095b9b5
commit
242981b8f0
@ -0,0 +1,156 @@
|
||||
"""written under MIT Licence, Michael Feil 2023."""
|
||||
|
||||
import asyncio
|
||||
from logging import getLogger
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
|
||||
|
||||
# Names exported by a star-import of this module.
__all__ = ["InfinityEmbeddingsLocal"]
|
||||
|
||||
# Module-level logger, named after this module.
logger = getLogger(__name__)
|
||||
|
||||
|
||||
class InfinityEmbeddingsLocal(BaseModel, Embeddings):
    """Optimized Embedding models https://github.com/michaelfeil/infinity

    This class deploys a local Infinity instance to embed text.
    The class requires async usage.

    Infinity is a class to interact with Embedding Models on
    https://github.com/michaelfeil/infinity

    Example:
        .. code-block:: python

            from langchain_community.embeddings import InfinityEmbeddingsLocal

            async with InfinityEmbeddingsLocal(
                model="BAAI/bge-small-en-v1.5",
                revision=None,
                device="cpu",
            ) as embedder:
                embeddings = await embedder.aembed_documents(["text1", "text2"])
    """

    model: str
    "Underlying model id from huggingface, e.g. BAAI/bge-small-en-v1.5"

    revision: Optional[str] = None
    "Model version, the commit hash from huggingface"

    batch_size: int = 32
    "Internal batch size for inference, e.g. 32"

    device: str = "auto"
    "Device to use for inference, e.g. 'cpu' or 'cuda', or 'mps'"

    backend: str = "torch"
    """Backend for inference, e.g. 'torch' (recommended for ROCm/Nvidia)
    or 'optimum' for onnx/tensorrt"""

    model_warmup: bool = True
    "Warmup the model with the max batch size."

    engine: Any = None  #: :meta private:
    """Infinity's AsyncEmbeddingEngine."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator(allow_reuse=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Check that `infinity_emb` is importable and build the engine.

        Args:
            values: The field values of the model being validated.

        Returns:
            The same mapping, with ``values["engine"]`` set to a new
            ``AsyncEmbeddingEngine`` configured from the other fields.

        Raises:
            ImportError: If the ``infinity_emb`` package cannot be imported.
        """
        try:
            from infinity_emb import AsyncEmbeddingEngine  # type: ignore
        except ImportError as exc:
            # Chain the original error so the real import failure stays visible.
            raise ImportError(
                "Please install the "
                "`pip install 'infinity_emb[optimum,torch]>=0.0.24'` "
                "package to use the InfinityEmbeddingsLocal."
            ) from exc
        logger.debug("Using InfinityEmbeddingsLocal with kwargs %s", values)

        values["engine"] = AsyncEmbeddingEngine(
            model_name_or_path=values["model"],
            device=values["device"],
            revision=values["revision"],
            model_warmup=values["model_warmup"],
            batch_size=values["batch_size"],
            engine=values["backend"],
        )
        return values

    async def __aenter__(self) -> "InfinityEmbeddingsLocal":
        """Start the background worker.

        Recommended usage is with the ``async with`` statement:

            async with InfinityEmbeddingsLocal(
                model="BAAI/bge-small-en-v1.5",
                revision=None,
                device="cpu",
            ) as embedder:
                embeddings = await embedder.aembed_documents(["text1", "text2"])

        Returns:
            This instance, so that ``async with ... as embedder`` binds a
            usable object.
        """
        await self.engine.__aenter__()
        return self

    async def __aexit__(self, *args: Any) -> None:
        """Stop the background worker.

        Required to free references to the pytorch model.

        Args:
            *args: Exception details supplied by the ``async with`` protocol;
                forwarded unchanged to the engine's ``__aexit__``.
        """
        await self.engine.__aexit__(*args)

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Async call out to Infinity's embedding endpoint.

        If the engine is not running yet, it is started for the duration of
        this single call and stopped again afterwards; a warning is logged in
        that case.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        if self.engine.running:
            embeddings, _ = await self.engine.embed(texts)
            return embeddings

        logger.warning(
            "Starting Infinity engine on the fly. This is not recommended. "
            "Please start the engine before using it."
        )
        async with self:
            # spawning threadpool for multithreaded encode, tokenization
            embeddings, _ = await self.engine.embed(texts)
        # stopping threadpool on exit
        logger.warning("Stopped infinity engine after usage.")
        return embeddings

    async def aembed_query(self, text: str) -> List[float]:
        """Async call out to Infinity's embedding endpoint.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        embeddings = await self.aembed_documents([text])
        return embeddings[0]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Synchronous wrapper around `aembed_documents`.

        This class is async only; this method logs a warning and drives the
        coroutine with ``asyncio.run``, which cannot be used while another
        event loop is already running in the same thread.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        logger.warning(
            "This method is async only. "
            "Please use the async version `await aembed_documents`."
        )
        return asyncio.run(self.aembed_documents(texts))

    def embed_query(self, text: str) -> List[float]:
        """Synchronous wrapper around `aembed_query`.

        This class is async only; this method logs a warning and drives the
        coroutine with ``asyncio.run``, which cannot be used while another
        event loop is already running in the same thread.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        logger.warning(
            "This method is async only."
            " Please use the async version `await aembed_query`."
        )
        return asyncio.run(self.aembed_query(text))
|
@ -0,0 +1,43 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from langchain_community.embeddings.infinity_local import InfinityEmbeddingsLocal
|
||||
|
||||
# Probe for the optional dependencies the integration test below needs;
# the flag is used to skip the test when either import fails.
try:
    import torch  # noqa
    import infinity_emb  # noqa
except ImportError:
    IMPORTED_TORCH = False
else:
    IMPORTED_TORCH = True
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    not IMPORTED_TORCH, reason="torch and infinity_emb must be installed"
)
@pytest.mark.asyncio
async def test_local_infinity_embeddings() -> None:
    """Embed three texts locally and check count, dimension and consistency.

    The first and third inputs are the same string, so their embeddings must
    match up to floating point error, while the second must differ.
    """
    embedder = InfinityEmbeddingsLocal(
        model="TaylorAI/bge-micro-v2",
        device="cpu",
        backend="torch",
        revision=None,
        batch_size=2,
        model_warmup=False,
    )

    async with embedder:
        embeddings = await embedder.aembed_documents(["text1", "text2", "text1"])
        assert len(embeddings) == 3
        # model has 384 dim output
        for embedding in embeddings:
            assert len(embedding) == 384
        # assert all different embeddings
        assert (np.array(embeddings[0]) - np.array(embeddings[1]) != 0).all()
        # assert identical embeddings, up to floating point error
        np.testing.assert_allclose(
            embeddings[0], embeddings[2], rtol=1e-4, atol=1e-6
        )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Allow running this integration test directly as a script, without pytest.
    import asyncio

    _test_coroutine = test_local_infinity_embeddings()
    asyncio.run(_test_coroutine)
|
Loading…
Reference in New Issue