From c8391d4ff16d9fce6d481e5468cd6fb0d475804f Mon Sep 17 00:00:00 2001 From: Egor Krasheninnikov Date: Sun, 14 Apr 2024 00:23:01 +0100 Subject: [PATCH] community[patch]: Fix YandexGPT embeddings (#19720) Fix of YandexGPT embeddings. The current version uses a single `model_name` for queries and documents, essentially making the `embed_documents` and `embed_query` methods the same. Yandex has a different endpoint (`model_uri`) for encoding documents, see [this](https://yandex.cloud/en/docs/yandexgpt/concepts/embeddings). The bug may impact retrievers built with `YandexGPTEmbeddings` (for instance FAISS database as retriever) since they use both `embed_documents` and `embed_query`. A simple snippet to test the behaviour: ```python from langchain_community.embeddings.yandex import YandexGPTEmbeddings embeddings = YandexGPTEmbeddings() q_emb = embeddings.embed_query('hello world') doc_emb = embeddings.embed_documents(['hello world', 'hello world']) q_emb == doc_emb[0] ``` The response is `True` with the current version and `False` with the changes I made. Twitter: @egor_krash --------- Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com> Co-authored-by: Bagatur --- .../langchain_community/embeddings/yandex.py | 50 +++++++++++++------ .../unit_tests/embeddings/test_yandex.py | 24 +++++++++ 2 files changed, 59 insertions(+), 15 deletions(-) create mode 100644 libs/community/tests/unit_tests/embeddings/test_yandex.py diff --git a/libs/community/langchain_community/embeddings/yandex.py b/libs/community/langchain_community/embeddings/yandex.py index 4183a3284c..c129c8169e 100644 --- a/libs/community/langchain_community/embeddings/yandex.py +++ b/libs/community/langchain_community/embeddings/yandex.py @@ -6,7 +6,7 @@ import time from typing import Any, Callable, Dict, List from langchain_core.embeddings import Embeddings -from langchain_core.pydantic_v1 import BaseModel, SecretStr, root_validator +from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env from tenacity import ( before_sleep_log, @@ -33,14 +33,13 @@ class YandexGPTEmbeddings(BaseModel, Embeddings): To use the default model specify the folder ID in a parameter `folder_id` or in an environment variable `YC_FOLDER_ID`. - Or specify the model URI in a constructor parameter `model_uri` Example: .. code-block:: python from langchain_community.embeddings.yandex import YandexGPTEmbeddings - embeddings = YandexGPTEmbeddings(iam_token="t1.9eu...", model_uri="emb:///text-search-query/latest") - """ + embeddings = YandexGPTEmbeddings(iam_token="t1.9eu...", folder_id=) + """ # noqa: E501 iam_token: SecretStr = "" # type: ignore[assignment] """Yandex Cloud IAM token for service account @@ -48,12 +47,16 @@ class YandexGPTEmbeddings(BaseModel, Embeddings): api_key: SecretStr = "" # type: ignore[assignment] """Yandex Cloud Api Key for service account with the `ai.languageModels.user` role""" - model_uri: str = "" - """Model uri to use.""" + model_uri: str = Field(default="", alias="query_model_uri") + """Query model uri to use.""" + doc_model_uri: str = "" + """Doc model uri to use.""" folder_id: str = "" """Yandex Cloud folder ID""" - model_name: str = "text-search-query" - """Model name to use.""" + doc_model_name: str = "text-search-doc" + """Doc model name to use.""" + model_name: str = Field(default="text-search-query", alias="query_model_name") + """Query model name to use.""" model_version: str = "latest" """Model version to use.""" url: str = "llm.api.cloud.yandex.net:443" @@ -63,6 +66,11 @@ class YandexGPTEmbeddings(BaseModel, Embeddings): sleep_interval: float = 0.0 """Delay between API requests""" + class Config: + """Configuration for this pydantic object.""" + + allow_population_by_field_name = True + @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validate that iam token exists in environment.""" @@ -89,12 +97,19 @@ class YandexGPTEmbeddings(BaseModel, Embeddings): values["_grpc_metadata"] = ( ("authorization", f"Api-Key {values['api_key'].get_secret_value()}"), ) - if values["model_uri"] == "" and values["folder_id"] == "": - raise ValueError("Either 'model_uri' or 'folder_id' must be provided.") - if not values["model_uri"]: + + if not values.get("doc_model_uri"): + if values["folder_id"] == "": + raise ValueError("'doc_model_uri' or 'folder_id' must be provided.") + values[ + "doc_model_uri" + ] = f"emb://{values['folder_id']}/{values['doc_model_name']}/{values['model_version']}" # noqa: E501 + if not values.get("model_uri"): + if values["folder_id"] == "": + raise ValueError("'model_uri' or 'folder_id' must be provided.") values[ "model_uri" - ] = f"emb://{values['folder_id']}/{values['model_name']}/{values['model_version']}" + ] = f"emb://{values['folder_id']}/{values['model_name']}/{values['model_version']}" # noqa: E501 return values def embed_documents(self, texts: List[str]) -> List[List[float]]: @@ -118,7 +133,7 @@ class YandexGPTEmbeddings(BaseModel, Embeddings): Returns: Embeddings for the text. """ - return _embed_with_retry(self, texts=[text])[0] + return _embed_with_retry(self, texts=[text], embed_query=True)[0] def _create_retry_decorator(llm: YandexGPTEmbeddings) -> Callable[[Any], Any]: @@ -146,7 +161,7 @@ def _embed_with_retry(llm: YandexGPTEmbeddings, **kwargs: Any) -> Any: return _completion_with_retry(**kwargs) -def _make_request(self: YandexGPTEmbeddings, texts: List[str]): # type: ignore[no-untyped-def] +def _make_request(self: YandexGPTEmbeddings, texts: List[str], **kwargs): # type: ignore[no-untyped-def] try: import grpc @@ -172,9 +187,14 @@ def _make_request(self: YandexGPTEmbeddings, texts: List[str]): # type: ignore[ result = [] channel_credentials = grpc.ssl_channel_credentials() channel = grpc.secure_channel(self.url, channel_credentials) + # Use the query model if embed_query is True + if kwargs.get("embed_query"): + model_uri = self.model_uri + else: + model_uri = self.doc_model_uri for text in texts: - request = TextEmbeddingRequest(model_uri=self.model_uri, text=text) + request = TextEmbeddingRequest(model_uri=model_uri, text=text) stub = EmbeddingsServiceStub(channel) res = stub.TextEmbedding(request, metadata=self._grpc_metadata) # type: ignore[attr-defined] result.append(list(res.embedding)) diff --git a/libs/community/tests/unit_tests/embeddings/test_yandex.py b/libs/community/tests/unit_tests/embeddings/test_yandex.py new file mode 100644 index 0000000000..2593927799 --- /dev/null +++ b/libs/community/tests/unit_tests/embeddings/test_yandex.py @@ -0,0 +1,24 @@ +import os + +from langchain_community.embeddings import YandexGPTEmbeddings + + +def test_init() -> None: + os.environ["YC_API_KEY"] = "foo" + models = [ + YandexGPTEmbeddings(folder_id="bar"), + YandexGPTEmbeddings( + query_model_uri="emb://bar/text-search-query/latest", + doc_model_uri="emb://bar/text-search-doc/latest", + ), + YandexGPTEmbeddings( + folder_id="bar", + query_model_name="text-search-query", + doc_model_name="text-search-doc", + ), + ] + for embeddings in models: + assert embeddings.model_uri == "emb://bar/text-search-query/latest" + assert embeddings.doc_model_uri == "emb://bar/text-search-doc/latest" + assert embeddings.model_name == "text-search-query" + assert embeddings.doc_model_name == "text-search-doc"