|
|
|
@ -1,7 +1,7 @@
|
|
|
|
|
"""Wrapper around OpenAI APIs."""
|
|
|
|
|
from typing import Any, Dict, List, Mapping, Optional
|
|
|
|
|
|
|
|
|
|
from pydantic import BaseModel, Extra, root_validator
|
|
|
|
|
from pydantic import BaseModel, Extra, Field, root_validator
|
|
|
|
|
|
|
|
|
|
from langchain.llms.base import LLM
|
|
|
|
|
from langchain.utils import get_from_dict_or_env
|
|
|
|
@ -37,6 +37,7 @@ class OpenAI(LLM, BaseModel):
|
|
|
|
|
"""How many completions to generate for each prompt."""
|
|
|
|
|
best_of: int = 1
|
|
|
|
|
"""Generates best_of completions server-side and returns the "best"."""
|
|
|
|
|
model_kwargs: dict = Field(default_factory=dict)
|
|
|
|
|
|
|
|
|
|
openai_api_key: Optional[str] = None
|
|
|
|
|
|
|
|
|
@ -63,10 +64,29 @@ class OpenAI(LLM, BaseModel):
|
|
|
|
|
)
|
|
|
|
|
return values
|
|
|
|
|
|
|
|
|
|
@root_validator()
|
|
|
|
|
def validate_model_kwargs(cls, values: Dict) -> Dict:
|
|
|
|
|
named_params = {
|
|
|
|
|
"temperature",
|
|
|
|
|
"max_tokens",
|
|
|
|
|
"top_p",
|
|
|
|
|
"frequency_penalty",
|
|
|
|
|
"presence_penalty",
|
|
|
|
|
"n",
|
|
|
|
|
"best_of",
|
|
|
|
|
}
|
|
|
|
|
overlap = named_params.intersection(values["model_kwargs"])
|
|
|
|
|
if overlap:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"Found named params in model_kwargs, "
|
|
|
|
|
f"should be specified separately: {overlap}"
|
|
|
|
|
)
|
|
|
|
|
return values
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def _default_params(self) -> Mapping[str, Any]:
|
|
|
|
|
"""Get the default parameters for calling OpenAI API."""
|
|
|
|
|
return {
|
|
|
|
|
named_params = {
|
|
|
|
|
"temperature": self.temperature,
|
|
|
|
|
"max_tokens": self.max_tokens,
|
|
|
|
|
"top_p": self.top_p,
|
|
|
|
@ -75,6 +95,7 @@ class OpenAI(LLM, BaseModel):
|
|
|
|
|
"n": self.n,
|
|
|
|
|
"best_of": self.best_of,
|
|
|
|
|
}
|
|
|
|
|
return {**named_params, **self.model_kwargs}
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def _identifying_params(self) -> Mapping[str, Any]:
|
|
|
|
|