|
|
|
@ -2,6 +2,7 @@
|
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
|
|
import logging
|
|
|
|
|
import time
|
|
|
|
|
from typing import Any, Callable, Dict, List
|
|
|
|
|
|
|
|
|
|
from langchain_core.embeddings import Embeddings
|
|
|
|
@ -59,6 +60,8 @@ class YandexGPTEmbeddings(BaseModel, Embeddings):
|
|
|
|
|
"""The url of the API."""
|
|
|
|
|
max_retries: int = 6
|
|
|
|
|
"""Maximum number of retries to make when generating."""
|
|
|
|
|
sleep_interval: float = 0.0
|
|
|
|
|
"""Delay between API requests"""
|
|
|
|
|
|
|
|
|
|
@root_validator()
|
|
|
|
|
def validate_environment(cls, values: Dict) -> Dict:
|
|
|
|
@ -154,7 +157,8 @@ def _make_request(self: YandexGPTEmbeddings, texts: List[str]):
|
|
|
|
|
)
|
|
|
|
|
except ImportError as e:
|
|
|
|
|
raise ImportError(
|
|
|
|
|
"Please install YandexCloud SDK" " with `pip install yandexcloud`."
|
|
|
|
|
"Please install YandexCloud SDK with `pip install yandexcloud` \
|
|
|
|
|
or upgrade it to recent version."
|
|
|
|
|
) from e
|
|
|
|
|
result = []
|
|
|
|
|
channel_credentials = grpc.ssl_channel_credentials()
|
|
|
|
@ -164,6 +168,7 @@ def _make_request(self: YandexGPTEmbeddings, texts: List[str]):
|
|
|
|
|
request = TextEmbeddingRequest(model_uri=self.model_uri, text=text)
|
|
|
|
|
stub = EmbeddingsServiceStub(channel)
|
|
|
|
|
res = stub.TextEmbedding(request, metadata=self._grpc_metadata)
|
|
|
|
|
result.append(res.embedding)
|
|
|
|
|
result.append(list(res.embedding))
|
|
|
|
|
time.sleep(self.sleep_interval)
|
|
|
|
|
|
|
|
|
|
return result
|
|
|
|
|