From 3ec970cc110d10fb748d75210809f81a10879d98 Mon Sep 17 00:00:00 2001
From: David Duong
Date: Tue, 3 Oct 2023 01:48:21 +0200
Subject: [PATCH] Mark Vertex AI classes as serialisable (#10484)

---------

Co-authored-by: Erick Friis
---
 libs/langchain/langchain/chat_models/vertexai.py |  4 ++++
 libs/langchain/langchain/llms/vertexai.py        | 10 +++++++---
 2 files changed, 11 insertions(+), 3 deletions(-)

diff --git a/libs/langchain/langchain/chat_models/vertexai.py b/libs/langchain/langchain/chat_models/vertexai.py
index 3407620e69..91da917ea6 100644
--- a/libs/langchain/langchain/chat_models/vertexai.py
+++ b/libs/langchain/langchain/chat_models/vertexai.py
@@ -124,6 +124,10 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
     model_name: str = "chat-bison"
     "Underlying model name."
 
+    @classmethod
+    def is_lc_serializable(self) -> bool:
+        return True
+
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that the python package exists in environment."""
diff --git a/libs/langchain/langchain/llms/vertexai.py b/libs/langchain/langchain/llms/vertexai.py
index b0e6ea2dd3..c8625699c9 100644
--- a/libs/langchain/langchain/llms/vertexai.py
+++ b/libs/langchain/langchain/llms/vertexai.py
@@ -18,7 +18,7 @@ from langchain.callbacks.manager import (
     CallbackManagerForLLMRun,
 )
 from langchain.llms.base import BaseLLM, create_base_retry_decorator
-from langchain.pydantic_v1 import BaseModel, root_validator
+from langchain.pydantic_v1 import BaseModel, Field, root_validator
 from langchain.schema import (
     Generation,
     LLMResult,
@@ -144,7 +144,7 @@ class _VertexAIBase(BaseModel):
     "Default is 5."
     max_retries: int = 6
     """The maximum number of retries to make when generating."""
-    task_executor: ClassVar[Optional[Executor]] = None
+    task_executor: ClassVar[Optional[Executor]] = Field(default=None, exclude=True)
     stop: Optional[List[str]] = None
     "Optional list of stop words to use when generating."
     model_name: Optional[str] = None
@@ -171,7 +171,7 @@ class _VertexAICommon(_VertexAIBase):
     top_k: int = 40
     "How the model selects tokens for output, the next token is selected from "
     "among the top-k most probable tokens. Top-k is ignored for Codey models."
-    credentials: Any = None
+    credentials: Any = Field(default=None, exclude=True)
     "The default custom credentials (google.auth.credentials.Credentials) to use "
     "when making API calls. If not provided, credentials will be ascertained from "
     "the environment."
@@ -229,6 +229,10 @@ class VertexAI(_VertexAICommon, BaseLLM):
     tuned_model_name: Optional[str] = None
     "The name of a tuned model. If provided, model_name is ignored."
 
+    @classmethod
+    def is_lc_serializable(self) -> bool:
+        return True
+
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that the python package exists in environment."""
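
A minimal usage sketch (illustrative, not part of the patch) of what this change enables, assuming the langchain.load helpers dumps/loads and an environment where the google-cloud-aiplatform SDK and credentials are available: with is_lc_serializable() returning True, a ChatVertexAI instance can be round-tripped through JSON, and fields marked exclude=True (credentials, task_executor) should be left out of the serialised payload.

# Sketch only: model/parameter values are illustrative, and constructing
# ChatVertexAI requires a working Vertex AI setup.
from langchain.chat_models import ChatVertexAI
from langchain.load.dump import dumps
from langchain.load.load import loads

chat = ChatVertexAI(model_name="chat-bison", temperature=0.2)

# With is_lc_serializable() returning True, dumps() emits a constructor
# payload rather than a "not implemented" stub.
serialized = dumps(chat, pretty=True)
print(serialized)  # the excluded "credentials" field should not appear here

# Reconstruct an equivalent instance from the JSON string.
restored = loads(serialized)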