mirror of
https://github.com/hwchase17/langchain
synced 2024-11-04 06:00:26 +00:00
Fix replicate model kwargs (#10599)
This commit is contained in:
parent
77a165e0d9
commit
e195b78e1d
@ -23,7 +23,7 @@ class Replicate(LLM):
|
||||
You can find your token here: https://replicate.com/account
|
||||
|
||||
The model param is required, but any other model parameters can also
|
||||
be passed in with the format input={model_param: value, ...}
|
||||
be passed in with the format model_kwargs={model_param: value, ...}
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
@ -35,13 +35,12 @@ class Replicate(LLM):
|
||||
"stability-ai/stable-diffusion: "
|
||||
"27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478",
|
||||
),
|
||||
input={"image_dimensions": "512x512"}
|
||||
model_kwargs={"image_dimensions": "512x512"}
|
||||
)
|
||||
"""
|
||||
|
||||
model: str
|
||||
input: Dict[str, Any] = Field(default_factory=dict)
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict, alias="input")
|
||||
replicate_api_token: Optional[str] = None
|
||||
prompt_key: Optional[str] = None
|
||||
version_obj: Any = Field(default=None, exclude=True)
|
||||
@ -59,6 +58,7 @@ class Replicate(LLM):
|
||||
class Config:
|
||||
"""Configuration for this pydantic config."""
|
||||
|
||||
allow_population_by_field_name = True
|
||||
extra = Extra.forbid
|
||||
|
||||
@property
|
||||
@ -74,7 +74,12 @@ class Replicate(LLM):
|
||||
"""Build extra kwargs from additional params that were passed in."""
|
||||
all_required_field_names = {field.alias for field in cls.__fields__.values()}
|
||||
|
||||
extra = values.get("model_kwargs", {})
|
||||
input = values.pop("input", {})
|
||||
if input:
|
||||
logger.warning(
|
||||
"Init param `input` is deprecated, please use `model_kwargs` instead."
|
||||
)
|
||||
extra = {**values.get("model_kwargs", {}), **input}
|
||||
for field_name in list(values):
|
||||
if field_name not in all_required_field_names:
|
||||
if field_name in extra:
|
||||
@ -202,7 +207,11 @@ class Replicate(LLM):
|
||||
|
||||
self.prompt_key = input_properties[0][0]
|
||||
|
||||
input_: Dict = {self.prompt_key: prompt, **self.input, **kwargs}
|
||||
input_: Dict = {
|
||||
self.prompt_key: prompt,
|
||||
**self.model_kwargs,
|
||||
**kwargs,
|
||||
}
|
||||
return replicate_python.predictions.create(
|
||||
version=self.version_obj, input=input_
|
||||
)
|
||||
|
@ -24,3 +24,21 @@ def test_replicate_streaming_call() -> None:
|
||||
output = llm("What is LangChain")
|
||||
assert output
|
||||
assert isinstance(output, str)
|
||||
|
||||
|
||||
def test_replicate_model_kwargs() -> None:
|
||||
"""Test simple non-streaming call to Replicate."""
|
||||
llm = Replicate(
|
||||
model=TEST_MODEL, model_kwargs={"max_length": 100, "temperature": 0.01}
|
||||
)
|
||||
long_output = llm("What is LangChain")
|
||||
llm = Replicate(
|
||||
model=TEST_MODEL, model_kwargs={"max_length": 10, "temperature": 0.01}
|
||||
)
|
||||
short_output = llm("What is LangChain")
|
||||
assert len(short_output) < len(long_output)
|
||||
|
||||
|
||||
def test_replicate_input() -> None:
|
||||
llm = Replicate(model=TEST_MODEL, input={"max_length": 10})
|
||||
assert llm.model_kwargs == {"max_length": 10}
|
||||
|
Loading…
Reference in New Issue
Block a user