diff --git a/langchain/chat_models/openai.py b/langchain/chat_models/openai.py
index 1b035fba..a4d2ac87 100644
--- a/langchain/chat_models/openai.py
+++ b/langchain/chat_models/openai.py
@@ -143,10 +143,24 @@ class ChatOpenAI(BaseChatModel):
 
         extra = values.get("model_kwargs", {})
         for field_name in list(values):
+            if field_name in extra:
+                raise ValueError(f"Found {field_name} supplied twice.")
             if field_name not in all_required_field_names:
-                if field_name in extra:
-                    raise ValueError(f"Found {field_name} supplied twice.")
+                logger.warning(
+                    f"""WARNING! {field_name} is not default parameter.
+                    {field_name} was transferred to model_kwargs.
+                    Please confirm that {field_name} is what you intended."""
+                )
                 extra[field_name] = values.pop(field_name)
+
+        disallowed_model_kwargs = all_required_field_names | {"model"}
+        invalid_model_kwargs = disallowed_model_kwargs.intersection(extra.keys())
+        if invalid_model_kwargs:
+            raise ValueError(
+                f"Parameters {invalid_model_kwargs} should be specified explicitly. "
+                f"Instead they were passed in as part of `model_kwargs` parameter."
+            )
+
         values["model_kwargs"] = extra
         return values
 
diff --git a/langchain/llms/openai.py b/langchain/llms/openai.py
index b29cbf60..bfca7fe3 100644
--- a/langchain/llms/openai.py
+++ b/langchain/llms/openai.py
@@ -186,15 +186,24 @@ class BaseOpenAI(BaseLLM):
 
         extra = values.get("model_kwargs", {})
         for field_name in list(values):
+            if field_name in extra:
+                raise ValueError(f"Found {field_name} supplied twice.")
             if field_name not in all_required_field_names:
-                if field_name in extra:
-                    raise ValueError(f"Found {field_name} supplied twice.")
                 logger.warning(
                     f"""WARNING! {field_name} is not default parameter.
-                    {field_name} was transfered to model_kwargs.
+                    {field_name} was transferred to model_kwargs.
                     Please confirm that {field_name} is what you intended."""
                 )
                 extra[field_name] = values.pop(field_name)
+
+        disallowed_model_kwargs = all_required_field_names | {"model"}
+        invalid_model_kwargs = disallowed_model_kwargs.intersection(extra.keys())
+        if invalid_model_kwargs:
+            raise ValueError(
+                f"Parameters {invalid_model_kwargs} should be specified explicitly. "
+                f"Instead they were passed in as part of `model_kwargs` parameter."
+            )
+
         values["model_kwargs"] = extra
         return values
 
@@ -422,7 +431,7 @@ class BaseOpenAI(BaseLLM):
     def prep_streaming_params(self, stop: Optional[List[str]] = None) -> Dict[str, Any]:
         """Prepare the params for streaming."""
         params = self._invocation_params
-        if params["best_of"] != 1:
+        if "best_of" in params and params["best_of"] != 1:
             raise ValueError("OpenAI only supports best_of == 1 for streaming")
         if stop is not None:
             if "stop" in params:
diff --git a/tests/integration_tests/chat_models/test_openai.py b/tests/integration_tests/chat_models/test_openai.py
index 432c2c88..679cb06c 100644
--- a/tests/integration_tests/chat_models/test_openai.py
+++ b/tests/integration_tests/chat_models/test_openai.py
@@ -147,3 +147,27 @@ async def test_async_chat_openai_streaming() -> None:
             assert isinstance(generation, ChatGeneration)
             assert isinstance(generation.text, str)
             assert generation.text == generation.message.content
+
+
+def test_chat_openai_extra_kwargs() -> None:
+    """Test extra kwargs to chat openai."""
+    # Check that foo is saved in extra_kwargs.
+    llm = ChatOpenAI(foo=3, max_tokens=10)
+    assert llm.max_tokens == 10
+    assert llm.model_kwargs == {"foo": 3}
+
+    # Test that if extra_kwargs are provided, they are added to it.
+    llm = ChatOpenAI(foo=3, model_kwargs={"bar": 2})
+    assert llm.model_kwargs == {"foo": 3, "bar": 2}
+
+    # Test that if provided twice it errors
+    with pytest.raises(ValueError):
+        ChatOpenAI(foo=3, model_kwargs={"foo": 2})
+
+    # Test that if explicit param is specified in kwargs it errors
+    with pytest.raises(ValueError):
+        ChatOpenAI(model_kwargs={"temperature": 0.2})
+
+    # Test that "model" cannot be specified in kwargs
+    with pytest.raises(ValueError):
+        ChatOpenAI(model_kwargs={"model": "text-davinci-003"})
diff --git a/tests/integration_tests/llms/test_openai.py b/tests/integration_tests/llms/test_openai.py
index e10a9c0b..5132e3bb 100644
--- a/tests/integration_tests/llms/test_openai.py
+++ b/tests/integration_tests/llms/test_openai.py
@@ -34,6 +34,14 @@ def test_openai_extra_kwargs() -> None:
     with pytest.raises(ValueError):
         OpenAI(foo=3, model_kwargs={"foo": 2})
 
+    # Test that if explicit param is specified in kwargs it errors
+    with pytest.raises(ValueError):
+        OpenAI(model_kwargs={"temperature": 0.2})
+
+    # Test that "model" cannot be specified in kwargs
+    with pytest.raises(ValueError):
+        OpenAI(model_kwargs={"model": "text-davinci-003"})
+
 
 def test_openai_llm_output_contains_model_name() -> None:
     """Test llm_output contains model_name."""
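
For context, a minimal usage sketch of the validation this patch introduces, as seen from a caller's side. The parameter values are illustrative only, and constructing the models still requires a configured OPENAI_API_KEY:

    from langchain.chat_models import ChatOpenAI

    # Unknown kwargs are still routed into model_kwargs (with a logged warning).
    llm = ChatOpenAI(foo=3, max_tokens=10)
    print(llm.model_kwargs)  # {'foo': 3}

    # Explicit params (and "model") can no longer be passed via model_kwargs.
    try:
        ChatOpenAI(model_kwargs={"temperature": 0.2})
    except ValueError as err:
        print(err)  # validation error: temperature should be specified explicitly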