Fix replicate model kwargs (#10599)

This commit is contained in:
Bagatur 2023-09-14 14:43:42 -07:00 committed by GitHub
parent 77a165e0d9
commit e195b78e1d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 33 additions and 6 deletions

View File

@ -23,7 +23,7 @@ class Replicate(LLM):
You can find your token here: https://replicate.com/account You can find your token here: https://replicate.com/account
The model param is required, but any other model parameters can also The model param is required, but any other model parameters can also
be passed in with the format input={model_param: value, ...} be passed in with the format model_kwargs={model_param: value, ...}
Example: Example:
.. code-block:: python .. code-block:: python
@ -35,13 +35,12 @@ class Replicate(LLM):
"stability-ai/stable-diffusion: " "stability-ai/stable-diffusion: "
"27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478", "27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478",
), ),
input={"image_dimensions": "512x512"} model_kwargs={"image_dimensions": "512x512"}
) )
""" """
model: str model: str
input: Dict[str, Any] = Field(default_factory=dict) model_kwargs: Dict[str, Any] = Field(default_factory=dict, alias="input")
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
replicate_api_token: Optional[str] = None replicate_api_token: Optional[str] = None
prompt_key: Optional[str] = None prompt_key: Optional[str] = None
version_obj: Any = Field(default=None, exclude=True) version_obj: Any = Field(default=None, exclude=True)
@ -59,6 +58,7 @@ class Replicate(LLM):
class Config: class Config:
"""Configuration for this pydantic config.""" """Configuration for this pydantic config."""
allow_population_by_field_name = True
extra = Extra.forbid extra = Extra.forbid
@property @property
@ -74,7 +74,12 @@ class Replicate(LLM):
"""Build extra kwargs from additional params that were passed in.""" """Build extra kwargs from additional params that were passed in."""
all_required_field_names = {field.alias for field in cls.__fields__.values()} all_required_field_names = {field.alias for field in cls.__fields__.values()}
extra = values.get("model_kwargs", {}) input = values.pop("input", {})
if input:
logger.warning(
"Init param `input` is deprecated, please use `model_kwargs` instead."
)
extra = {**values.get("model_kwargs", {}), **input}
for field_name in list(values): for field_name in list(values):
if field_name not in all_required_field_names: if field_name not in all_required_field_names:
if field_name in extra: if field_name in extra:
@ -202,7 +207,11 @@ class Replicate(LLM):
self.prompt_key = input_properties[0][0] self.prompt_key = input_properties[0][0]
input_: Dict = {self.prompt_key: prompt, **self.input, **kwargs} input_: Dict = {
self.prompt_key: prompt,
**self.model_kwargs,
**kwargs,
}
return replicate_python.predictions.create( return replicate_python.predictions.create(
version=self.version_obj, input=input_ version=self.version_obj, input=input_
) )

View File

@ -24,3 +24,21 @@ def test_replicate_streaming_call() -> None:
output = llm("What is LangChain") output = llm("What is LangChain")
assert output assert output
assert isinstance(output, str) assert isinstance(output, str)
def test_replicate_model_kwargs() -> None:
"""Test simple non-streaming call to Replicate."""
llm = Replicate(
model=TEST_MODEL, model_kwargs={"max_length": 100, "temperature": 0.01}
)
long_output = llm("What is LangChain")
llm = Replicate(
model=TEST_MODEL, model_kwargs={"max_length": 10, "temperature": 0.01}
)
short_output = llm("What is LangChain")
assert len(short_output) < len(long_output)
def test_replicate_input() -> None:
llm = Replicate(model=TEST_MODEL, input={"max_length": 10})
assert llm.model_kwargs == {"max_length": 10}