mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
Add embeddings for LocalAI (#8134)
Description: This PR adds embeddings for LocalAI ( https://github.com/go-skynet/LocalAI ), a self-hosted OpenAI drop-in replacement. As LocalAI can re-use OpenAI clients it is mostly following the lines of the OpenAI embeddings, however when embedding documents, it just uses string instead of sending tokens as sending tokens is best-effort depending on the model being used in LocalAI. Sending tokens is also tricky as token id's can mismatch with the model - so it's safer to just send strings in this case. Partly related to: https://github.com/hwchase17/langchain/issues/5256 Dependencies: No new dependencies Twitter: @mudler_it --------- Signed-off-by: mudler <mudler@localai.io> Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
d983046f90
commit
ae28568e2a
161
docs/extras/integrations/text_embedding/localai.ipynb
Normal file
161
docs/extras/integrations/text_embedding/localai.ipynb
Normal file
@ -0,0 +1,161 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"attachments": {},
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "278b6c63",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# LocalAI\n",
|
||||||
|
"\n",
|
||||||
|
"Let's load the LocalAI Embedding class. In order to use the LocalAI Embedding class, you need to have the LocalAI service hosted somewhere and configure the embedding models. See the documentation at https://localai.io/basics/getting_started/index.html and https://localai.io/features/embeddings/index.html."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"id": "0be1af71",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from langchain.embeddings import LocalAIEmbeddings"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"id": "2c66e5da",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"embeddings = LocalAIEmbeddings(openai_api_base=\"http://localhost:8080\", model=\"embedding-model-name\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"id": "01370375",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"text = \"This is a test document.\""
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 4,
|
||||||
|
"id": "bfb6142c",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"query_result = embeddings.embed_query(text)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 5,
|
||||||
|
"id": "0356c3b7",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"doc_result = embeddings.embed_documents([text])"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"attachments": {},
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "bb61bbeb",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"Let's load the LocalAI Embedding class with first generation models (e.g. text-search-ada-doc-001/text-search-ada-query-001). Note: These are not recommended models - see [here](https://platform.openai.com/docs/guides/embeddings/what-are-embeddings)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "c0b072cc",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from langchain.embeddings.openai import LocalAIEmbeddings"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "a56b70f5",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"embeddings = LocalAIEmbeddings(openai_api_base=\"http://localhost:8080\", model=\"embedding-model-name\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "14aefb64",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"text = \"This is a test document.\""
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "3c39ed33",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"query_result = embeddings.embed_query(text)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "e3221db6",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"doc_result = embeddings.embed_documents([text])"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "aaad49f8",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# if you are behind an explicit proxy, you can use the OPENAI_PROXY environment variable to pass through\n",
|
||||||
|
"os.environ[\"OPENAI_PROXY\"] = \"http://proxy.yourcompany.com:8080\""
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3.11.1 64-bit",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.11.1"
|
||||||
|
},
|
||||||
|
"vscode": {
|
||||||
|
"interpreter": {
|
||||||
|
"hash": "e971737741ff4ec9aff7dc6155a1060a59a8a6d52c757dbbe66bf8ee389494b1"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
@ -23,6 +23,7 @@ from langchain.embeddings.huggingface import (
|
|||||||
from langchain.embeddings.huggingface_hub import HuggingFaceHubEmbeddings
|
from langchain.embeddings.huggingface_hub import HuggingFaceHubEmbeddings
|
||||||
from langchain.embeddings.jina import JinaEmbeddings
|
from langchain.embeddings.jina import JinaEmbeddings
|
||||||
from langchain.embeddings.llamacpp import LlamaCppEmbeddings
|
from langchain.embeddings.llamacpp import LlamaCppEmbeddings
|
||||||
|
from langchain.embeddings.localai import LocalAIEmbeddings
|
||||||
from langchain.embeddings.minimax import MiniMaxEmbeddings
|
from langchain.embeddings.minimax import MiniMaxEmbeddings
|
||||||
from langchain.embeddings.mlflow_gateway import MlflowAIGatewayEmbeddings
|
from langchain.embeddings.mlflow_gateway import MlflowAIGatewayEmbeddings
|
||||||
from langchain.embeddings.modelscope_hub import ModelScopeEmbeddings
|
from langchain.embeddings.modelscope_hub import ModelScopeEmbeddings
|
||||||
@ -76,6 +77,7 @@ __all__ = [
|
|||||||
"SpacyEmbeddings",
|
"SpacyEmbeddings",
|
||||||
"NLPCloudEmbeddings",
|
"NLPCloudEmbeddings",
|
||||||
"GPT4AllEmbeddings",
|
"GPT4AllEmbeddings",
|
||||||
|
"LocalAIEmbeddings",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
345
libs/langchain/langchain/embeddings/localai.py
Normal file
345
libs/langchain/langchain/embeddings/localai.py
Normal file
@ -0,0 +1,345 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import warnings
|
||||||
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Callable,
|
||||||
|
Dict,
|
||||||
|
List,
|
||||||
|
Literal,
|
||||||
|
Optional,
|
||||||
|
Sequence,
|
||||||
|
Set,
|
||||||
|
Tuple,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Extra, Field, root_validator
|
||||||
|
from tenacity import (
|
||||||
|
AsyncRetrying,
|
||||||
|
before_sleep_log,
|
||||||
|
retry,
|
||||||
|
retry_if_exception_type,
|
||||||
|
stop_after_attempt,
|
||||||
|
wait_exponential,
|
||||||
|
)
|
||||||
|
|
||||||
|
from langchain.embeddings.base import Embeddings
|
||||||
|
from langchain.utils import get_from_dict_or_env, get_pydantic_field_names
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _create_retry_decorator(embeddings: LocalAIEmbeddings) -> Callable[[Any], Any]:
|
||||||
|
import openai
|
||||||
|
|
||||||
|
min_seconds = 4
|
||||||
|
max_seconds = 10
|
||||||
|
# Wait 2^x * 1 second between each retry starting with
|
||||||
|
# 4 seconds, then up to 10 seconds, then 10 seconds afterwards
|
||||||
|
return retry(
|
||||||
|
reraise=True,
|
||||||
|
stop=stop_after_attempt(embeddings.max_retries),
|
||||||
|
wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
|
||||||
|
retry=(
|
||||||
|
retry_if_exception_type(openai.error.Timeout)
|
||||||
|
| retry_if_exception_type(openai.error.APIError)
|
||||||
|
| retry_if_exception_type(openai.error.APIConnectionError)
|
||||||
|
| retry_if_exception_type(openai.error.RateLimitError)
|
||||||
|
| retry_if_exception_type(openai.error.ServiceUnavailableError)
|
||||||
|
),
|
||||||
|
before_sleep=before_sleep_log(logger, logging.WARNING),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _async_retry_decorator(embeddings: LocalAIEmbeddings) -> Any:
|
||||||
|
import openai
|
||||||
|
|
||||||
|
min_seconds = 4
|
||||||
|
max_seconds = 10
|
||||||
|
# Wait 2^x * 1 second between each retry starting with
|
||||||
|
# 4 seconds, then up to 10 seconds, then 10 seconds afterwards
|
||||||
|
async_retrying = AsyncRetrying(
|
||||||
|
reraise=True,
|
||||||
|
stop=stop_after_attempt(embeddings.max_retries),
|
||||||
|
wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
|
||||||
|
retry=(
|
||||||
|
retry_if_exception_type(openai.error.Timeout)
|
||||||
|
| retry_if_exception_type(openai.error.APIError)
|
||||||
|
| retry_if_exception_type(openai.error.APIConnectionError)
|
||||||
|
| retry_if_exception_type(openai.error.RateLimitError)
|
||||||
|
| retry_if_exception_type(openai.error.ServiceUnavailableError)
|
||||||
|
),
|
||||||
|
before_sleep=before_sleep_log(logger, logging.WARNING),
|
||||||
|
)
|
||||||
|
|
||||||
|
def wrap(func: Callable) -> Callable:
|
||||||
|
async def wrapped_f(*args: Any, **kwargs: Any) -> Callable:
|
||||||
|
async for _ in async_retrying:
|
||||||
|
return await func(*args, **kwargs)
|
||||||
|
raise AssertionError("this is unreachable")
|
||||||
|
|
||||||
|
return wrapped_f
|
||||||
|
|
||||||
|
return wrap
|
||||||
|
|
||||||
|
|
||||||
|
# https://stackoverflow.com/questions/76469415/getting-embeddings-of-length-1-from-langchain-openaiembeddings
|
||||||
|
def _check_response(response: dict) -> dict:
|
||||||
|
if any(len(d["embedding"]) == 1 for d in response["data"]):
|
||||||
|
import openai
|
||||||
|
|
||||||
|
raise openai.error.APIError("LocalAI API returned an empty embedding")
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
def embed_with_retry(embeddings: LocalAIEmbeddings, **kwargs: Any) -> Any:
|
||||||
|
"""Use tenacity to retry the embedding call."""
|
||||||
|
retry_decorator = _create_retry_decorator(embeddings)
|
||||||
|
|
||||||
|
@retry_decorator
|
||||||
|
def _embed_with_retry(**kwargs: Any) -> Any:
|
||||||
|
response = embeddings.client.create(**kwargs)
|
||||||
|
return _check_response(response)
|
||||||
|
|
||||||
|
return _embed_with_retry(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
async def async_embed_with_retry(embeddings: LocalAIEmbeddings, **kwargs: Any) -> Any:
|
||||||
|
"""Use tenacity to retry the embedding call."""
|
||||||
|
|
||||||
|
@_async_retry_decorator(embeddings)
|
||||||
|
async def _async_embed_with_retry(**kwargs: Any) -> Any:
|
||||||
|
response = await embeddings.client.acreate(**kwargs)
|
||||||
|
return _check_response(response)
|
||||||
|
|
||||||
|
return await _async_embed_with_retry(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class LocalAIEmbeddings(BaseModel, Embeddings):
|
||||||
|
"""LocalAI embedding models.
|
||||||
|
|
||||||
|
To use, you should have the ``openai`` python package installed, and the
|
||||||
|
environment variable ``OPENAI_API_KEY`` set to a random string. You need to
|
||||||
|
specify ``OPENAI_API_BASE`` to point to your LocalAI service endpoint.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from langchain.embeddings import LocalAIEmbeddings
|
||||||
|
openai = LocalAIEmbeddings(
|
||||||
|
openai_api_key="random-key",
|
||||||
|
openai_api_base="http://localhost:8080"
|
||||||
|
)
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
client: Any #: :meta private:
|
||||||
|
model: str = "text-embedding-ada-002"
|
||||||
|
deployment: str = model
|
||||||
|
openai_api_version: Optional[str] = None
|
||||||
|
openai_api_base: Optional[str] = None
|
||||||
|
# to support explicit proxy for LocalAI
|
||||||
|
openai_proxy: Optional[str] = None
|
||||||
|
embedding_ctx_length: int = 8191
|
||||||
|
"""The maximum number of tokens to embed at once."""
|
||||||
|
openai_api_key: Optional[str] = None
|
||||||
|
openai_organization: Optional[str] = None
|
||||||
|
allowed_special: Union[Literal["all"], Set[str]] = set()
|
||||||
|
disallowed_special: Union[Literal["all"], Set[str], Sequence[str]] = "all"
|
||||||
|
chunk_size: int = 1000
|
||||||
|
"""Maximum number of texts to embed in each batch"""
|
||||||
|
max_retries: int = 6
|
||||||
|
"""Maximum number of retries to make when generating."""
|
||||||
|
request_timeout: Optional[Union[float, Tuple[float, float]]] = None
|
||||||
|
"""Timeout in seconds for the LocalAI request."""
|
||||||
|
headers: Any = None
|
||||||
|
show_progress_bar: bool = False
|
||||||
|
"""Whether to show a progress bar when embedding."""
|
||||||
|
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||||
|
"""Holds any model parameters valid for `create` call not explicitly specified."""
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
"""Configuration for this pydantic object."""
|
||||||
|
|
||||||
|
extra = Extra.forbid
|
||||||
|
|
||||||
|
@root_validator(pre=True)
|
||||||
|
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""Build extra kwargs from additional params that were passed in."""
|
||||||
|
all_required_field_names = get_pydantic_field_names(cls)
|
||||||
|
extra = values.get("model_kwargs", {})
|
||||||
|
for field_name in list(values):
|
||||||
|
if field_name in extra:
|
||||||
|
raise ValueError(f"Found {field_name} supplied twice.")
|
||||||
|
if field_name not in all_required_field_names:
|
||||||
|
warnings.warn(
|
||||||
|
f"""WARNING! {field_name} is not default parameter.
|
||||||
|
{field_name} was transferred to model_kwargs.
|
||||||
|
Please confirm that {field_name} is what you intended."""
|
||||||
|
)
|
||||||
|
extra[field_name] = values.pop(field_name)
|
||||||
|
|
||||||
|
invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
|
||||||
|
if invalid_model_kwargs:
|
||||||
|
raise ValueError(
|
||||||
|
f"Parameters {invalid_model_kwargs} should be specified explicitly. "
|
||||||
|
f"Instead they were passed in as part of `model_kwargs` parameter."
|
||||||
|
)
|
||||||
|
|
||||||
|
values["model_kwargs"] = extra
|
||||||
|
return values
|
||||||
|
|
||||||
|
@root_validator()
|
||||||
|
def validate_environment(cls, values: Dict) -> Dict:
|
||||||
|
"""Validate that api key and python package exists in environment."""
|
||||||
|
values["openai_api_key"] = get_from_dict_or_env(
|
||||||
|
values, "openai_api_key", "OPENAI_API_KEY"
|
||||||
|
)
|
||||||
|
values["openai_api_base"] = get_from_dict_or_env(
|
||||||
|
values,
|
||||||
|
"openai_api_base",
|
||||||
|
"OPENAI_API_BASE",
|
||||||
|
default="",
|
||||||
|
)
|
||||||
|
values["openai_proxy"] = get_from_dict_or_env(
|
||||||
|
values,
|
||||||
|
"openai_proxy",
|
||||||
|
"OPENAI_PROXY",
|
||||||
|
default="",
|
||||||
|
)
|
||||||
|
|
||||||
|
default_api_version = ""
|
||||||
|
values["openai_api_version"] = get_from_dict_or_env(
|
||||||
|
values,
|
||||||
|
"openai_api_version",
|
||||||
|
"OPENAI_API_VERSION",
|
||||||
|
default=default_api_version,
|
||||||
|
)
|
||||||
|
values["openai_organization"] = get_from_dict_or_env(
|
||||||
|
values,
|
||||||
|
"openai_organization",
|
||||||
|
"OPENAI_ORGANIZATION",
|
||||||
|
default="",
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
import openai
|
||||||
|
|
||||||
|
values["client"] = openai.Embedding
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"Could not import openai python package. "
|
||||||
|
"Please install it with `pip install openai`."
|
||||||
|
)
|
||||||
|
return values
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _invocation_params(self) -> Dict:
|
||||||
|
openai_args = {
|
||||||
|
"model": self.model,
|
||||||
|
"request_timeout": self.request_timeout,
|
||||||
|
"headers": self.headers,
|
||||||
|
"api_key": self.openai_api_key,
|
||||||
|
"organization": self.openai_organization,
|
||||||
|
"api_base": self.openai_api_base,
|
||||||
|
"api_version": self.openai_api_version,
|
||||||
|
**self.model_kwargs,
|
||||||
|
}
|
||||||
|
if self.openai_proxy:
|
||||||
|
import openai
|
||||||
|
|
||||||
|
openai.proxy = {
|
||||||
|
"http": self.openai_proxy,
|
||||||
|
"https": self.openai_proxy,
|
||||||
|
} # type: ignore[assignment] # noqa: E501
|
||||||
|
return openai_args
|
||||||
|
|
||||||
|
def _embedding_func(self, text: str, *, engine: str) -> List[float]:
|
||||||
|
"""Call out to LocalAI's embedding endpoint."""
|
||||||
|
# handle large input text
|
||||||
|
if self.model.endswith("001"):
|
||||||
|
# See: https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
|
||||||
|
# replace newlines, which can negatively affect performance.
|
||||||
|
text = text.replace("\n", " ")
|
||||||
|
return embed_with_retry(
|
||||||
|
self,
|
||||||
|
input=[text],
|
||||||
|
**self._invocation_params,
|
||||||
|
)["data"][
|
||||||
|
0
|
||||||
|
]["embedding"]
|
||||||
|
|
||||||
|
async def _aembedding_func(self, text: str, *, engine: str) -> List[float]:
|
||||||
|
"""Call out to LocalAI's embedding endpoint."""
|
||||||
|
# handle large input text
|
||||||
|
if self.model.endswith("001"):
|
||||||
|
# See: https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
|
||||||
|
# replace newlines, which can negatively affect performance.
|
||||||
|
text = text.replace("\n", " ")
|
||||||
|
return (
|
||||||
|
await async_embed_with_retry(
|
||||||
|
self,
|
||||||
|
input=[text],
|
||||||
|
**self._invocation_params,
|
||||||
|
)
|
||||||
|
)["data"][0]["embedding"]
|
||||||
|
|
||||||
|
def embed_documents(
|
||||||
|
self, texts: List[str], chunk_size: Optional[int] = 0
|
||||||
|
) -> List[List[float]]:
|
||||||
|
"""Call out to LocalAI's embedding endpoint for embedding search docs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
texts: The list of texts to embed.
|
||||||
|
chunk_size: The chunk size of embeddings. If None, will use the chunk size
|
||||||
|
specified by the class.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of embeddings, one for each text.
|
||||||
|
"""
|
||||||
|
# call _embedding_func for each text
|
||||||
|
return [self._embedding_func(text, engine=self.deployment) for text in texts]
|
||||||
|
|
||||||
|
async def aembed_documents(
|
||||||
|
self, texts: List[str], chunk_size: Optional[int] = 0
|
||||||
|
) -> List[List[float]]:
|
||||||
|
"""Call out to LocalAI's embedding endpoint async for embedding search docs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
texts: The list of texts to embed.
|
||||||
|
chunk_size: The chunk size of embeddings. If None, will use the chunk size
|
||||||
|
specified by the class.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of embeddings, one for each text.
|
||||||
|
"""
|
||||||
|
embeddings = []
|
||||||
|
for text in texts:
|
||||||
|
response = await self._aembedding_func(text, engine=self.deployment)
|
||||||
|
embeddings.append(response)
|
||||||
|
return embeddings
|
||||||
|
|
||||||
|
def embed_query(self, text: str) -> List[float]:
|
||||||
|
"""Call out to LocalAI's embedding endpoint for embedding query text.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: The text to embed.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Embedding for the text.
|
||||||
|
"""
|
||||||
|
embedding = self._embedding_func(text, engine=self.deployment)
|
||||||
|
return embedding
|
||||||
|
|
||||||
|
async def aembed_query(self, text: str) -> List[float]:
|
||||||
|
"""Call out to LocalAI's embedding endpoint async for embedding query text.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: The text to embed.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Embedding for the text.
|
||||||
|
"""
|
||||||
|
embedding = await self._aembedding_func(text, engine=self.deployment)
|
||||||
|
return embedding
|
Loading…
Reference in New Issue
Block a user