groq[patch]: Update root validators for pydantic 2 migration (#25402)

This commit is contained in:
Eugene Yurtsev 2024-08-15 14:46:52 -04:00 committed by GitHub
parent 8eb63a609e
commit d72a08a60d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -3,7 +3,6 @@
from __future__ import annotations
import json
import os
import warnings
from operator import itemgetter
from typing import (
@ -75,9 +74,9 @@ from langchain_core.pydantic_v1 import (
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
from langchain_core.tools import BaseTool
from langchain_core.utils import (
convert_to_secret_str,
get_from_dict_or_env,
from_env,
get_pydantic_field_names,
secret_from_env,
)
from langchain_core.utils.function_calling import (
convert_to_openai_function,
@ -308,13 +307,19 @@ class ChatGroq(BaseChatModel):
"""Default stop sequences."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `create` call not explicitly specified."""
groq_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
groq_api_key: Optional[SecretStr] = Field(
alias="api_key", default_factory=secret_from_env("GROQ_API_KEY", default=None)
)
"""Automatically inferred from env var `GROQ_API_KEY` if not provided."""
groq_api_base: Optional[str] = Field(default=None, alias="base_url")
groq_api_base: Optional[str] = Field(
alias="base_url", default_factory=from_env("GROQ_API_BASE", default=None)
)
"""Base URL path for API requests, leave blank if not using a proxy or service
emulator."""
# to support explicit proxy for Groq
groq_proxy: Optional[str] = None
groq_proxy: Optional[str] = Field(
default_factory=from_env("GROQ_PROXY", default=None)
)
request_timeout: Union[float, Tuple[float, float], Any, None] = Field(
default=None, alias="timeout"
)
@ -369,25 +374,20 @@ class ChatGroq(BaseChatModel):
values["model_kwargs"] = extra
return values
@root_validator()
@root_validator(pre=False, skip_on_failure=True)
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
if values["n"] < 1:
raise ValueError("n must be at least 1.")
if values["n"] > 1 and values["streaming"]:
raise ValueError("n must be 1 when streaming.")
if values["temperature"] == 0:
values["temperature"] = 1e-8
values["groq_api_key"] = convert_to_secret_str(
get_from_dict_or_env(values, "groq_api_key", "GROQ_API_KEY")
)
values["groq_api_base"] = values["groq_api_base"] or os.getenv("GROQ_API_BASE")
values["groq_proxy"] = values["groq_proxy"] = os.getenv("GROQ_PROXY")
client_params = {
"api_key": values["groq_api_key"].get_secret_value(),
"api_key": values["groq_api_key"].get_secret_value()
if values["groq_api_key"]
else None,
"base_url": values["groq_api_base"],
"timeout": values["request_timeout"],
"max_retries": values["max_retries"],