From b150a0e504d6d38563903d539b2bd1091e702194 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Thu, 24 Nov 2022 06:43:24 -0800 Subject: [PATCH] flexible model args --- langchain/llms/openai.py | 25 +++++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/langchain/llms/openai.py b/langchain/llms/openai.py index 2affb86d..8bd857c0 100644 --- a/langchain/llms/openai.py +++ b/langchain/llms/openai.py @@ -1,7 +1,7 @@ """Wrapper around OpenAI APIs.""" from typing import Any, Dict, List, Mapping, Optional -from pydantic import BaseModel, Extra, root_validator +from pydantic import BaseModel, Extra, Field, root_validator from langchain.llms.base import LLM from langchain.utils import get_from_dict_or_env @@ -37,6 +37,7 @@ class OpenAI(LLM, BaseModel): """How many completions to generate for each prompt.""" best_of: int = 1 """Generates best_of completions server-side and returns the "best".""" + model_kwargs: dict = Field(default_factory=dict) openai_api_key: Optional[str] = None @@ -63,10 +64,29 @@ class OpenAI(LLM, BaseModel): ) return values + @root_validator() + def validate_model_kwargs(cls, values: Dict) -> Dict: + named_params = { + "temperature", + "max_tokens", + "top_p", + "frequency_penalty", + "presence_penalty", + "n", + "best_of", + } + overlap = named_params.intersection(values["model_kwargs"]) + if overlap: + raise ValueError( + "Found named params in model_kwargs, " + f"should be specified separately: {overlap}" + ) + return values + @property def _default_params(self) -> Mapping[str, Any]: """Get the default parameters for calling OpenAI API.""" - return { + named_params = { "temperature": self.temperature, "max_tokens": self.max_tokens, "top_p": self.top_p, @@ -75,6 +95,7 @@ class OpenAI(LLM, BaseModel): "n": self.n, "best_of": self.best_of, } + return {**named_params, **self.model_kwargs} @property def _identifying_params(self) -> Mapping[str, Any]: