|
|
|
@ -124,7 +124,7 @@ class BaseOpenAI(BaseLLM):
|
|
|
|
|
"""Wrapper around OpenAI large language models."""
|
|
|
|
|
|
|
|
|
|
client: Any #: :meta private:
|
|
|
|
|
model_name: str = "text-davinci-003"
|
|
|
|
|
model_name: str = Field("text-davinci-003", alias="model")
|
|
|
|
|
"""Model name to use."""
|
|
|
|
|
temperature: float = 0.7
|
|
|
|
|
"""What sampling temperature to use."""
|
|
|
|
@ -178,12 +178,12 @@ class BaseOpenAI(BaseLLM):
|
|
|
|
|
"""Configuration for this pydantic object."""
|
|
|
|
|
|
|
|
|
|
extra = Extra.ignore
|
|
|
|
|
allow_population_by_field_name = True
|
|
|
|
|
|
|
|
|
|
@root_validator(pre=True)
|
|
|
|
|
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
|
|
|
|
"""Build extra kwargs from additional params that were passed in."""
|
|
|
|
|
all_required_field_names = {field.alias for field in cls.__fields__.values()}
|
|
|
|
|
|
|
|
|
|
all_required_field_names = cls.all_required_field_names()
|
|
|
|
|
extra = values.get("model_kwargs", {})
|
|
|
|
|
for field_name in list(values):
|
|
|
|
|
if field_name in extra:
|
|
|
|
@ -196,8 +196,7 @@ class BaseOpenAI(BaseLLM):
|
|
|
|
|
)
|
|
|
|
|
extra[field_name] = values.pop(field_name)
|
|
|
|
|
|
|
|
|
|
disallowed_model_kwargs = all_required_field_names | {"model"}
|
|
|
|
|
invalid_model_kwargs = disallowed_model_kwargs.intersection(extra.keys())
|
|
|
|
|
invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
|
|
|
|
|
if invalid_model_kwargs:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
f"Parameters {invalid_model_kwargs} should be specified explicitly. "
|
|
|
|
|