# (Web UI residue) You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
# File: langchain/libs/community/langchain_community/embeddings/volcengine.py
#
# 129 lines
# 4.1 KiB
# Python

from __future__ import annotations
import logging
from typing import Any, Dict, List, Optional
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, root_validator
from langchain_core.utils import get_from_dict_or_env
logger = logging.getLogger(__name__)
class VolcanoEmbeddings(BaseModel, Embeddings):
    """`Volcengine Embeddings` embedding models.

    On validation, a ``volcengine.maas.MaasService`` client is built from
    ``host`` and ``region`` and configured with the access/secret key pair.
    Texts are sent to the client's ``embeddings`` endpoint in batches of
    ``chunk_size``.
    """

    volcano_ak: Optional[str] = None
    """volcano access key, resolved from the field or the ``VOLC_ACCESSKEY``
    environment variable
    learn more from: https://www.volcengine.com/docs/6459/76491#ak-sk"""

    volcano_sk: Optional[str] = None
    """volcano secret key, resolved from the field or the ``VOLC_SECRETKEY``
    environment variable
    learn more from: https://www.volcengine.com/docs/6459/76491#ak-sk"""

    host: str = "maas-api.ml-platform-cn-beijing.volces.com"
    """host
    learn more from https://www.volcengine.com/docs/82379/1174746"""

    region: str = "cn-beijing"
    """region
    learn more from https://www.volcengine.com/docs/82379/1174746"""

    model: str = "bge-large-zh"
    """Model name
    you could get from https://www.volcengine.com/docs/82379/1174746
    for now, we support bge_large_zh
    """

    version: str = "1.0"
    """ model version """

    chunk_size: int = 100
    """Chunk size when multiple texts are input"""

    client: Any
    """volcano client"""

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve credentials and build the volcano embedding client.

        ``volcano_ak`` and ``volcano_sk`` are looked up through
        ``get_from_dict_or_env`` (field value, else the ``VOLC_ACCESSKEY`` /
        ``VOLC_SECRETKEY`` environment variables). A ``MaasService`` client is
        then created with ``host`` and ``region`` and given the key pair.

        Args:
            values: a dictionary containing the model's configuration fields.

        Returns:
            The same dictionary with ``volcano_ak``, ``volcano_sk`` and
            ``client`` filled in.

        Raises:
            ImportError: volcengine package not found, please install it with
                `pip install volcengine`
        """
        values["volcano_ak"] = get_from_dict_or_env(
            values,
            "volcano_ak",
            "VOLC_ACCESSKEY",
        )
        values["volcano_sk"] = get_from_dict_or_env(
            values,
            "volcano_sk",
            "VOLC_SECRETKEY",
        )

        # Only the import itself is guarded, so an unrelated ImportError raised
        # while constructing the client is not misreported as a missing package.
        try:
            from volcengine.maas import MaasService
        except ImportError as e:
            raise ImportError(
                "volcengine package not found, please install it with "
                "`pip install volcengine`"
            ) from e

        client = MaasService(values["host"], values["region"])
        client.set_ak(values["volcano_ak"])
        client.set_sk(values["volcano_sk"])
        values["client"] = client
        return values

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query text.

        Args:
            text: the text to embed.

        Returns:
            The embedding of ``text``, obtained by delegating to
            ``embed_documents`` with a one-element list.
        """
        return self.embed_documents([text])[0]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of texts with the configured volcengine model.

        The texts are split into consecutive batches of at most ``chunk_size``
        items; each batch is sent to ``self.client.embeddings`` and the
        ``"embedding"`` entries of the response's ``"data"`` list are collected
        in order.

        Args:
            texts (List[str]): A list of text documents to embed.

        Returns:
            List[List[float]]: One embedding per input text, in input order.
            An empty input yields an empty list.

        Raises:
            ValueError: if ``chunk_size`` is not a positive integer, or if the
                volcengine client raises a ``MaasException``.
        """
        # A negative step would make ``range`` empty and silently drop every
        # text; zero would fail with an opaque ``range()`` error.
        if self.chunk_size < 1:
            raise ValueError(
                f"chunk_size must be a positive integer, got {self.chunk_size}"
            )

        # Imported once, outside the ``try``: if the import lived inside it, a
        # failed import would leave ``MaasException`` unbound and the ``except``
        # clause would mask the ImportError with an unbound-name error.
        from volcengine.maas import MaasException

        embeddings: List[List[float]] = []
        for start in range(0, len(texts), self.chunk_size):
            req = {
                "model": {
                    "name": self.model,
                    "version": self.version,
                },
                "input": texts[start : start + self.chunk_size],
            }
            try:
                resp = self.client.embeddings(req)
            except MaasException as e:
                # Keep the original exception as the cause for debugging.
                raise ValueError(f"embed by volcengine Error: {e}") from e
            embeddings.extend(res["embedding"] for res in resp["data"])
        return embeddings