From 265e650e64943fae2457871c93723a849badc20c Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Wed, 12 Jun 2024 13:59:05 -0400 Subject: [PATCH] community[patch]: Update root_validators embeddings: llamacpp, jina, dashscope, mosaicml, huggingface_hub, Toolkits: Connery, ChatModels: PAI_EAS, (#22828) This PR updates root validators for: * Embeddings: llamacpp, jina, dashscope, mosaicml, huggingface_hub * Toolkits: Connery * ChatModels: PAI_EAS Following this issue: https://github.com/langchain-ai/langchain/issues/22819 --- .../agent_toolkits/connery/toolkit.py | 2 +- .../chat_models/pai_eas_endpoint.py | 2 +- .../embeddings/dashscope.py | 2 +- .../embeddings/huggingface_hub.py | 27 +++++++++++-------- .../langchain_community/embeddings/jina.py | 2 +- .../embeddings/llamacpp.py | 2 +- .../embeddings/mosaicml.py | 2 +- 7 files changed, 22 insertions(+), 17 deletions(-) diff --git a/libs/community/langchain_community/agent_toolkits/connery/toolkit.py b/libs/community/langchain_community/agent_toolkits/connery/toolkit.py index b48b16ac93..05a8a68706 100644 --- a/libs/community/langchain_community/agent_toolkits/connery/toolkit.py +++ b/libs/community/langchain_community/agent_toolkits/connery/toolkit.py @@ -19,7 +19,7 @@ class ConneryToolkit(BaseToolkit): """ return self.tools - @root_validator() + @root_validator(pre=True) def validate_attributes(cls, values: dict) -> dict: """ Validate the attributes of the ConneryToolkit class. diff --git a/libs/community/langchain_community/chat_models/pai_eas_endpoint.py b/libs/community/langchain_community/chat_models/pai_eas_endpoint.py index e438ad25ee..48c7eb9b42 100644 --- a/libs/community/langchain_community/chat_models/pai_eas_endpoint.py +++ b/libs/community/langchain_community/chat_models/pai_eas_endpoint.py @@ -67,7 +67,7 @@ class PaiEasChatEndpoint(BaseChatModel): timeout: Optional[int] = 5000 - @root_validator() + @root_validator(pre=True) def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" values["eas_service_url"] = get_from_dict_or_env( diff --git a/libs/community/langchain_community/embeddings/dashscope.py b/libs/community/langchain_community/embeddings/dashscope.py index 0042c05deb..1bdf72ba7e 100644 --- a/libs/community/langchain_community/embeddings/dashscope.py +++ b/libs/community/langchain_community/embeddings/dashscope.py @@ -110,7 +110,7 @@ class DashScopeEmbeddings(BaseModel, Embeddings): extra = Extra.forbid - @root_validator() + @root_validator(pre=True) def validate_environment(cls, values: Dict) -> Dict: import dashscope diff --git a/libs/community/langchain_community/embeddings/huggingface_hub.py b/libs/community/langchain_community/embeddings/huggingface_hub.py index 3600a9a4aa..0ee0881d2e 100644 --- a/libs/community/langchain_community/embeddings/huggingface_hub.py +++ b/libs/community/langchain_community/embeddings/huggingface_hub.py @@ -1,10 +1,10 @@ import json -import os from typing import Any, Dict, List, Optional from langchain_core._api import deprecated from langchain_core.embeddings import Embeddings from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator +from langchain_core.utils import get_from_dict_or_env DEFAULT_MODEL = "sentence-transformers/all-mpnet-base-v2" VALID_TASKS = ("feature-extraction",) @@ -52,19 +52,19 @@ class HuggingFaceHubEmbeddings(BaseModel, Embeddings): extra = Extra.forbid - @root_validator() + @root_validator(pre=True) def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" - huggingfacehub_api_token = values["huggingfacehub_api_token"] or os.getenv( - "HUGGINGFACEHUB_API_TOKEN" + huggingfacehub_api_token = get_from_dict_or_env( + values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN" ) try: from huggingface_hub import AsyncInferenceClient, InferenceClient - if values["model"]: + if values.get("model"): values["repo_id"] = values["model"] - elif values["repo_id"]: + elif values.get("repo_id"): values["model"] = values["repo_id"] else: values["model"] = DEFAULT_MODEL @@ -80,11 +80,6 @@ class HuggingFaceHubEmbeddings(BaseModel, Embeddings): token=huggingfacehub_api_token, ) - if values["task"] not in VALID_TASKS: - raise ValueError( - f"Got invalid task {values['task']}, " - f"currently only {VALID_TASKS} are supported" - ) values["client"] = client values["async_client"] = async_client @@ -95,6 +90,16 @@ class HuggingFaceHubEmbeddings(BaseModel, Embeddings): ) return values + @root_validator(pre=False, skip_on_failure=True) + def post_init(cls, values: Dict) -> Dict: + """Post init validation for the class.""" + if values["task"] not in VALID_TASKS: + raise ValueError( + f"Got invalid task {values['task']}, " + f"currently only {VALID_TASKS} are supported" + ) + return values + def embed_documents(self, texts: List[str]) -> List[List[float]]: """Call out to HuggingFaceHub's embedding endpoint for embedding search docs. diff --git a/libs/community/langchain_community/embeddings/jina.py b/libs/community/langchain_community/embeddings/jina.py index 7c50faf46b..d62a36924f 100644 --- a/libs/community/langchain_community/embeddings/jina.py +++ b/libs/community/langchain_community/embeddings/jina.py @@ -30,7 +30,7 @@ class JinaEmbeddings(BaseModel, Embeddings): model_name: str = "jina-embeddings-v2-base-en" jina_api_key: Optional[SecretStr] = None - @root_validator() + @root_validator(pre=True) def validate_environment(cls, values: Dict) -> Dict: """Validate that auth token exists in environment.""" try: diff --git a/libs/community/langchain_community/embeddings/llamacpp.py b/libs/community/langchain_community/embeddings/llamacpp.py index fcf7367597..43293e33ce 100644 --- a/libs/community/langchain_community/embeddings/llamacpp.py +++ b/libs/community/langchain_community/embeddings/llamacpp.py @@ -62,7 +62,7 @@ class LlamaCppEmbeddings(BaseModel, Embeddings): extra = Extra.forbid - @root_validator() + @root_validator(pre=False, skip_on_failure=True) def validate_environment(cls, values: Dict) -> Dict: """Validate that llama-cpp-python library is installed.""" model_path = values["model_path"] diff --git a/libs/community/langchain_community/embeddings/mosaicml.py b/libs/community/langchain_community/embeddings/mosaicml.py index bd8d97bb7e..69b8d1ba3b 100644 --- a/libs/community/langchain_community/embeddings/mosaicml.py +++ b/libs/community/langchain_community/embeddings/mosaicml.py @@ -46,7 +46,7 @@ class MosaicMLInstructorEmbeddings(BaseModel, Embeddings): extra = Extra.forbid - @root_validator() + @root_validator(pre=True) def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" mosaicml_api_token = get_from_dict_or_env(