Compare commits

...

1 Commits

Author SHA1 Message Date
Harrison Chase
b150a0e504 flexible model args 2022-11-24 06:43:24 -08:00

View File

@ -1,7 +1,7 @@
"""Wrapper around OpenAI APIs."""
from typing import Any, Dict, List, Mapping, Optional
from pydantic import BaseModel, Extra, root_validator
from pydantic import BaseModel, Extra, Field, root_validator
from langchain.llms.base import LLM
from langchain.utils import get_from_dict_or_env
@ -37,6 +37,7 @@ class OpenAI(LLM, BaseModel):
"""How many completions to generate for each prompt."""
best_of: int = 1
"""Generates best_of completions server-side and returns the "best"."""
model_kwargs: dict = Field(default_factory=dict)
openai_api_key: Optional[str] = None
@ -63,10 +64,29 @@ class OpenAI(LLM, BaseModel):
)
return values
@root_validator()
def validate_model_kwargs(cls, values: Dict) -> Dict:
named_params = {
"temperature",
"max_tokens",
"top_p",
"frequency_penalty",
"presence_penalty",
"n",
"best_of",
}
overlap = named_params.intersection(values["model_kwargs"])
if overlap:
raise ValueError(
"Found named params in model_kwargs, "
f"should be specified separately: {overlap}"
)
return values
@property
def _default_params(self) -> Mapping[str, Any]:
"""Get the default parameters for calling OpenAI API."""
return {
named_params = {
"temperature": self.temperature,
"max_tokens": self.max_tokens,
"top_p": self.top_p,
@ -75,6 +95,7 @@ class OpenAI(LLM, BaseModel):
"n": self.n,
"best_of": self.best_of,
}
return {**named_params, **self.model_kwargs}
@property
def _identifying_params(self) -> Mapping[str, Any]: