Check OpenAI model kwargs (#4366)

Handle duplicate and incorrectly specified OpenAI params

Thanks @PawelFaron for the fix! Made small update

Closes #4331

---------

Co-authored-by: PawelFaron <42373772+PawelFaron@users.noreply.github.com>
Co-authored-by: Pawel Faron <ext-pawel.faron@vaisala.com>
parallel_dir_loader
Davis Chase 1 year ago committed by GitHub
parent 02ebb15c4a
commit ba0057c077
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -143,10 +143,24 @@ class ChatOpenAI(BaseChatModel):
extra = values.get("model_kwargs", {})
for field_name in list(values):
if field_name in extra:
raise ValueError(f"Found {field_name} supplied twice.")
if field_name not in all_required_field_names:
if field_name in extra:
raise ValueError(f"Found {field_name} supplied twice.")
logger.warning(
f"""WARNING! {field_name} is not default parameter.
{field_name} was transferred to model_kwargs.
Please confirm that {field_name} is what you intended."""
)
extra[field_name] = values.pop(field_name)
disallowed_model_kwargs = all_required_field_names | {"model"}
invalid_model_kwargs = disallowed_model_kwargs.intersection(extra.keys())
if invalid_model_kwargs:
raise ValueError(
f"Parameters {invalid_model_kwargs} should be specified explicitly. "
f"Instead they were passed in as part of `model_kwargs` parameter."
)
values["model_kwargs"] = extra
return values

@ -186,15 +186,24 @@ class BaseOpenAI(BaseLLM):
extra = values.get("model_kwargs", {})
for field_name in list(values):
if field_name in extra:
raise ValueError(f"Found {field_name} supplied twice.")
if field_name not in all_required_field_names:
if field_name in extra:
raise ValueError(f"Found {field_name} supplied twice.")
logger.warning(
f"""WARNING! {field_name} is not default parameter.
{field_name} was transfered to model_kwargs.
{field_name} was transferred to model_kwargs.
Please confirm that {field_name} is what you intended."""
)
extra[field_name] = values.pop(field_name)
disallowed_model_kwargs = all_required_field_names | {"model"}
invalid_model_kwargs = disallowed_model_kwargs.intersection(extra.keys())
if invalid_model_kwargs:
raise ValueError(
f"Parameters {invalid_model_kwargs} should be specified explicitly. "
f"Instead they were passed in as part of `model_kwargs` parameter."
)
values["model_kwargs"] = extra
return values
@ -422,7 +431,7 @@ class BaseOpenAI(BaseLLM):
def prep_streaming_params(self, stop: Optional[List[str]] = None) -> Dict[str, Any]:
"""Prepare the params for streaming."""
params = self._invocation_params
if params["best_of"] != 1:
if "best_of" in params and params["best_of"] != 1:
raise ValueError("OpenAI only supports best_of == 1 for streaming")
if stop is not None:
if "stop" in params:

@ -147,3 +147,27 @@ async def test_async_chat_openai_streaming() -> None:
assert isinstance(generation, ChatGeneration)
assert isinstance(generation.text, str)
assert generation.text == generation.message.content
def test_chat_openai_extra_kwargs() -> None:
"""Test extra kwargs to chat openai."""
# Check that foo is saved in extra_kwargs.
llm = ChatOpenAI(foo=3, max_tokens=10)
assert llm.max_tokens == 10
assert llm.model_kwargs == {"foo": 3}
# Test that if extra_kwargs are provided, they are added to it.
llm = ChatOpenAI(foo=3, model_kwargs={"bar": 2})
assert llm.model_kwargs == {"foo": 3, "bar": 2}
# Test that if provided twice it errors
with pytest.raises(ValueError):
ChatOpenAI(foo=3, model_kwargs={"foo": 2})
# Test that if explicit param is specified in kwargs it errors
with pytest.raises(ValueError):
ChatOpenAI(model_kwargs={"temperature": 0.2})
# Test that "model" cannot be specified in kwargs
with pytest.raises(ValueError):
ChatOpenAI(model_kwargs={"model": "text-davinci-003"})

@ -34,6 +34,14 @@ def test_openai_extra_kwargs() -> None:
with pytest.raises(ValueError):
OpenAI(foo=3, model_kwargs={"foo": 2})
# Test that if explicit param is specified in kwargs it errors
with pytest.raises(ValueError):
OpenAI(model_kwargs={"temperature": 0.2})
# Test that "model" cannot be specified in kwargs
with pytest.raises(ValueError):
OpenAI(model_kwargs={"model": "text-davinci-003"})
def test_openai_llm_output_contains_model_name() -> None:
"""Test llm_output contains model_name."""

Loading…
Cancel
Save