diff --git a/libs/partners/google-genai/langchain_google_genai/chat_models.py b/libs/partners/google-genai/langchain_google_genai/chat_models.py index b1f289095f..d226aeefc4 100644 --- a/libs/partners/google-genai/langchain_google_genai/chat_models.py +++ b/libs/partners/google-genai/langchain_google_genai/chat_models.py @@ -466,6 +466,14 @@ Supported examples: Gemini does not support system messages; any unsupported messages will raise an error.""" + client_options: Optional[Dict] = Field( + None, + description="Client options to pass to the Google API client.", + ) + transport: Optional[str] = Field( + None, + description="A string, one of: [`rest`, `grpc`, `grpc_asyncio`].", + ) class Config: allow_population_by_field_name = True @@ -488,12 +496,18 @@ Supported examples: @root_validator() def validate_environment(cls, values: Dict) -> Dict: + """Validates params and passes them to google-generativeai package.""" google_api_key = get_from_dict_or_env( values, "google_api_key", "GOOGLE_API_KEY" ) if isinstance(google_api_key, SecretStr): google_api_key = google_api_key.get_secret_value() - genai.configure(api_key=google_api_key) + + genai.configure( + api_key=google_api_key, + transport=values.get("transport"), + client_options=values.get("client_options"), + ) if ( values.get("temperature") is not None and not 0 <= values["temperature"] <= 1 diff --git a/libs/partners/google-genai/langchain_google_genai/embeddings.py b/libs/partners/google-genai/langchain_google_genai/embeddings.py index 0b581265fe..5e61581e01 100644 --- a/libs/partners/google-genai/langchain_google_genai/embeddings.py +++ b/libs/partners/google-genai/langchain_google_genai/embeddings.py @@ -43,16 +43,32 @@ class GoogleGenerativeAIEmbeddings(BaseModel, Embeddings): description="The Google API key to use. If not provided, " "the GOOGLE_API_KEY environment variable will be used.", ) + client_options: Optional[Dict] = Field( + None, + description=( + "A dictionary of client options to pass to the Google API client, " + "such as `api_endpoint`." + ), + ) + transport: Optional[str] = Field( + None, + description="A string, one of: [`rest`, `grpc`, `grpc_asyncio`].", + ) @root_validator() def validate_environment(cls, values: Dict) -> Dict: - """Validates that the python package exists in environment.""" + """Validates params and passes them to google-generativeai package.""" google_api_key = get_from_dict_or_env( values, "google_api_key", "GOOGLE_API_KEY" ) if isinstance(google_api_key, SecretStr): google_api_key = google_api_key.get_secret_value() - genai.configure(api_key=google_api_key) + + genai.configure( + api_key=google_api_key, + transport=values.get("transport"), + client_options=values.get("client_options"), + ) return values def _embed( diff --git a/libs/partners/google-genai/langchain_google_genai/llms.py b/libs/partners/google-genai/langchain_google_genai/llms.py index fd57a7acab..ec9ec4d67f 100644 --- a/libs/partners/google-genai/langchain_google_genai/llms.py +++ b/libs/partners/google-genai/langchain_google_genai/llms.py @@ -121,6 +121,17 @@ Supported examples: not return the full n completions if duplicates are generated.""" max_retries: int = 6 """The maximum number of retries to make when generating.""" + client_options: Optional[Dict] = Field( + None, + description=( + "A dictionary of client options to pass to the Google API client, " + "such as `api_endpoint`." + ), + ) + transport: Optional[str] = Field( + None, + description="A string, one of: [`rest`, `grpc`, `grpc_asyncio`].", + ) @property def is_gemini(self) -> bool: @@ -133,7 +144,7 @@ Supported examples: @root_validator() def validate_environment(cls, values: Dict) -> Dict: - """Validate api key, python package exists.""" + """Validates params and passes them to google-generativeai package.""" google_api_key = get_from_dict_or_env( values, "google_api_key", "GOOGLE_API_KEY" ) @@ -142,7 +153,11 @@ Supported examples: if isinstance(google_api_key, SecretStr): google_api_key = google_api_key.get_secret_value() - genai.configure(api_key=google_api_key) + genai.configure( + api_key=google_api_key, + transport=values.get("transport"), + client_options=values.get("client_options"), + ) if _is_gemini_model(model_name): values["client"] = genai.GenerativeModel(model_name=model_name)