mirror of
https://github.com/hwchase17/langchain
synced 2024-11-10 01:10:59 +00:00
groq[patch]: Update root validators for pydantic 2 migration (#25402)
This commit is contained in:
parent
8eb63a609e
commit
d72a08a60d
@ -3,7 +3,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import warnings
|
||||
from operator import itemgetter
|
||||
from typing import (
|
||||
@ -75,9 +74,9 @@ from langchain_core.pydantic_v1 import (
|
||||
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils import (
|
||||
convert_to_secret_str,
|
||||
get_from_dict_or_env,
|
||||
from_env,
|
||||
get_pydantic_field_names,
|
||||
secret_from_env,
|
||||
)
|
||||
from langchain_core.utils.function_calling import (
|
||||
convert_to_openai_function,
|
||||
@ -308,13 +307,19 @@ class ChatGroq(BaseChatModel):
|
||||
"""Default stop sequences."""
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
"""Holds any model parameters valid for `create` call not explicitly specified."""
|
||||
groq_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
|
||||
groq_api_key: Optional[SecretStr] = Field(
|
||||
alias="api_key", default_factory=secret_from_env("GROQ_API_KEY", default=None)
|
||||
)
|
||||
"""Automatically inferred from env var `GROQ_API_KEY` if not provided."""
|
||||
groq_api_base: Optional[str] = Field(default=None, alias="base_url")
|
||||
groq_api_base: Optional[str] = Field(
|
||||
alias="base_url", default_factory=from_env("GROQ_API_BASE", default=None)
|
||||
)
|
||||
"""Base URL path for API requests, leave blank if not using a proxy or service
|
||||
emulator."""
|
||||
# to support explicit proxy for Groq
|
||||
groq_proxy: Optional[str] = None
|
||||
groq_proxy: Optional[str] = Field(
|
||||
default_factory=from_env("GROQ_PROXY", default=None)
|
||||
)
|
||||
request_timeout: Union[float, Tuple[float, float], Any, None] = Field(
|
||||
default=None, alias="timeout"
|
||||
)
|
||||
@ -369,25 +374,20 @@ class ChatGroq(BaseChatModel):
|
||||
values["model_kwargs"] = extra
|
||||
return values
|
||||
|
||||
@root_validator()
|
||||
@root_validator(pre=False, skip_on_failure=True)
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
if values["n"] < 1:
|
||||
raise ValueError("n must be at least 1.")
|
||||
if values["n"] > 1 and values["streaming"]:
|
||||
raise ValueError("n must be 1 when streaming.")
|
||||
|
||||
if values["temperature"] == 0:
|
||||
values["temperature"] = 1e-8
|
||||
|
||||
values["groq_api_key"] = convert_to_secret_str(
|
||||
get_from_dict_or_env(values, "groq_api_key", "GROQ_API_KEY")
|
||||
)
|
||||
values["groq_api_base"] = values["groq_api_base"] or os.getenv("GROQ_API_BASE")
|
||||
values["groq_proxy"] = values["groq_proxy"] = os.getenv("GROQ_PROXY")
|
||||
|
||||
client_params = {
|
||||
"api_key": values["groq_api_key"].get_secret_value(),
|
||||
"api_key": values["groq_api_key"].get_secret_value()
|
||||
if values["groq_api_key"]
|
||||
else None,
|
||||
"base_url": values["groq_api_base"],
|
||||
"timeout": values["request_timeout"],
|
||||
"max_retries": values["max_retries"],
|
||||
|
Loading…
Reference in New Issue
Block a user