Fix replicate model kwargs (#10599)

This commit is contained in:
Bagatur 2023-09-14 14:43:42 -07:00 committed by GitHub
parent 77a165e0d9
commit e195b78e1d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 33 additions and 6 deletions

View File

@ -23,7 +23,7 @@ class Replicate(LLM):
You can find your token here: https://replicate.com/account
The model param is required, but any other model parameters can also
be passed in with the format input={model_param: value, ...}
be passed in with the format model_kwargs={model_param: value, ...}
Example:
.. code-block:: python
@ -35,13 +35,12 @@ class Replicate(LLM):
"stability-ai/stable-diffusion:"
"27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478"
),
input={"image_dimensions": "512x512"}
model_kwargs={"image_dimensions": "512x512"}
)
"""
model: str
input: Dict[str, Any] = Field(default_factory=dict)
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
model_kwargs: Dict[str, Any] = Field(default_factory=dict, alias="input")
replicate_api_token: Optional[str] = None
prompt_key: Optional[str] = None
version_obj: Any = Field(default=None, exclude=True)
@ -59,6 +58,7 @@ class Replicate(LLM):
class Config:
"""Configuration for this pydantic config."""
allow_population_by_field_name = True
extra = Extra.forbid
@property
@ -74,7 +74,12 @@ class Replicate(LLM):
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = {field.alias for field in cls.__fields__.values()}
extra = values.get("model_kwargs", {})
input = values.pop("input", {})
if input:
logger.warning(
"Init param `input` is deprecated, please use `model_kwargs` instead."
)
extra = {**values.get("model_kwargs", {}), **input}
for field_name in list(values):
if field_name not in all_required_field_names:
if field_name in extra:
@ -202,7 +207,11 @@ class Replicate(LLM):
self.prompt_key = input_properties[0][0]
input_: Dict = {self.prompt_key: prompt, **self.input, **kwargs}
input_: Dict = {
self.prompt_key: prompt,
**self.model_kwargs,
**kwargs,
}
return replicate_python.predictions.create(
version=self.version_obj, input=input_
)

View File

@ -24,3 +24,21 @@ def test_replicate_streaming_call() -> None:
output = llm("What is LangChain")
assert output
assert isinstance(output, str)
def test_replicate_model_kwargs() -> None:
    """Test that ``model_kwargs`` are forwarded to the Replicate model.

    The same prompt is run twice, once with a large ``max_length`` and once
    with a small one; the run with the smaller limit is expected to return
    the shorter completion.
    """
    # Generous length limit: this is the baseline to compare against.
    llm = Replicate(
        model=TEST_MODEL, model_kwargs={"max_length": 100, "temperature": 0.01}
    )
    long_output = llm("What is LangChain")
    # Same settings except for a much tighter length limit.
    llm = Replicate(
        model=TEST_MODEL, model_kwargs={"max_length": 10, "temperature": 0.01}
    )
    short_output = llm("What is LangChain")
    assert len(short_output) < len(long_output)
def test_replicate_input() -> None:
    """Check that the deprecated ``input`` init param ends up in ``model_kwargs``."""
    legacy_params = {"max_length": 10}
    llm = Replicate(model=TEST_MODEL, input=legacy_params)
    assert llm.model_kwargs == {"max_length": 10}