|
|
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
|
|
|
|
|
|
|
|
|
import logging
|
|
|
|
|
import time
|
|
|
|
|
from typing import Any, Callable, Dict, List
|
|
|
|
|
from typing import Any, Callable, Dict, List, Sequence
|
|
|
|
|
|
|
|
|
|
from langchain_core.embeddings import Embeddings
|
|
|
|
|
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
|
|
|
|
@ -65,6 +65,10 @@ class YandexGPTEmbeddings(BaseModel, Embeddings):
|
|
|
|
|
"""Maximum number of retries to make when generating."""
|
|
|
|
|
sleep_interval: float = 0.0
|
|
|
|
|
"""Delay between API requests"""
|
|
|
|
|
disable_request_logging: bool = False
|
|
|
|
|
"""YandexGPT API logs all request data by default.
|
|
|
|
|
If you provide personal data, confidential information, disable logging."""
|
|
|
|
|
_grpc_metadata: Sequence
|
|
|
|
|
|
|
|
|
|
class Config:
|
|
|
|
|
"""Configuration for this pydantic object."""
|
|
|
|
@ -110,6 +114,13 @@ class YandexGPTEmbeddings(BaseModel, Embeddings):
|
|
|
|
|
values[
|
|
|
|
|
"model_uri"
|
|
|
|
|
] = f"emb://{values['folder_id']}/{values['model_name']}/{values['model_version']}" # noqa: E501
|
|
|
|
|
if values["disable_request_logging"]:
|
|
|
|
|
values["_grpc_metadata"].append(
|
|
|
|
|
(
|
|
|
|
|
"x-data-logging-enabled",
|
|
|
|
|
"false",
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
return values
|
|
|
|
|
|
|
|
|
|
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
|
|
|
|