Mirror of https://github.com/hwchase17/langchain
Fix replicate model kwargs (#10599)
commit e195b78e1d
parent 77a165e0d9
@@ -23,7 +23,7 @@ class Replicate(LLM):
     You can find your token here: https://replicate.com/account
 
     The model param is required, but any other model parameters can also
-    be passed in with the format input={model_param: value, ...}
+    be passed in with the format model_kwargs={model_param: value, ...}
 
     Example:
         .. code-block:: python
@@ -35,13 +35,12 @@ class Replicate(LLM):
                     "stability-ai/stable-diffusion: "
                     "27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478",
                 ),
-                input={"image_dimensions": "512x512"}
+                model_kwargs={"image_dimensions": "512x512"}
             )
     """
 
     model: str
-    input: Dict[str, Any] = Field(default_factory=dict)
-    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
+    model_kwargs: Dict[str, Any] = Field(default_factory=dict, alias="input")
     replicate_api_token: Optional[str] = None
     prompt_key: Optional[str] = None
     version_obj: Any = Field(default=None, exclude=True)
@@ -59,6 +58,7 @@ class Replicate(LLM):
     class Config:
         """Configuration for this pydantic config."""
 
+        allow_population_by_field_name = True
         extra = Extra.forbid
 
     @property
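
The two hunks above collapse the old separate `input` field into a single `model_kwargs` field aliased to "input", and turn on `allow_population_by_field_name` so either keyword name populates it. A minimal standalone sketch of that mechanism, assuming the pydantic v1 API LangChain used at the time (not taken from this diff):

    from typing import Any, Dict

    from pydantic import BaseModel, Field


    class Example(BaseModel):
        # One field, reachable under two names: "model_kwargs" (field name)
        # and "input" (alias).
        model_kwargs: Dict[str, Any] = Field(default_factory=dict, alias="input")

        class Config:
            allow_population_by_field_name = True


    # Both spellings land on the same attribute.
    assert Example(input={"max_length": 10}).model_kwargs == {"max_length": 10}
    assert Example(model_kwargs={"max_length": 10}).model_kwargs == {"max_length": 10}
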
@@ -74,7 +74,12 @@ class Replicate(LLM):
         """Build extra kwargs from additional params that were passed in."""
         all_required_field_names = {field.alias for field in cls.__fields__.values()}
 
-        extra = values.get("model_kwargs", {})
+        input = values.pop("input", {})
+        if input:
+            logger.warning(
+                "Init param `input` is deprecated, please use `model_kwargs` instead."
+            )
+        extra = {**values.get("model_kwargs", {}), **input}
         for field_name in list(values):
             if field_name not in all_required_field_names:
                 if field_name in extra:
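
An aside on the new merge line, `extra = {**values.get("model_kwargs", {}), **input}`: if a caller somehow passes both keywords, the values from the deprecated `input` dict win because they are unpacked last. A tiny illustration, not part of the diff:

    # `input` shadows the builtin here, mirroring the variable name in the diff.
    values = {"model_kwargs": {"temperature": 0.5}, "input": {"temperature": 0.01}}
    input = values.pop("input", {})
    extra = {**values.get("model_kwargs", {}), **input}
    assert extra == {"temperature": 0.01}
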
@@ -202,7 +207,11 @@ class Replicate(LLM):
 
         self.prompt_key = input_properties[0][0]
 
-        input_: Dict = {self.prompt_key: prompt, **self.input, **kwargs}
+        input_: Dict = {
+            self.prompt_key: prompt,
+            **self.model_kwargs,
+            **kwargs,
+        }
         return replicate_python.predictions.create(
             version=self.version_obj, input=input_
         )
@@ -24,3 +24,21 @@ def test_replicate_streaming_call() -> None:
     output = llm("What is LangChain")
     assert output
     assert isinstance(output, str)
+
+
+def test_replicate_model_kwargs() -> None:
+    """Test simple non-streaming call to Replicate."""
+    llm = Replicate(
+        model=TEST_MODEL, model_kwargs={"max_length": 100, "temperature": 0.01}
+    )
+    long_output = llm("What is LangChain")
+    llm = Replicate(
+        model=TEST_MODEL, model_kwargs={"max_length": 10, "temperature": 0.01}
+    )
+    short_output = llm("What is LangChain")
+    assert len(short_output) < len(long_output)
+
+
+def test_replicate_input() -> None:
+    llm = Replicate(model=TEST_MODEL, input={"max_length": 10})
+    assert llm.model_kwargs == {"max_length": 10}
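
A hedged usage sketch of the resulting API, pieced together from the updated docstring and the new tests. It assumes the `replicate` package is installed and `REPLICATE_API_TOKEN` is set, and reuses the stable-diffusion version string from the docstring example (any Replicate model identifier would do):

    from langchain.llms import Replicate

    model = (
        "stability-ai/stable-diffusion:"
        "27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478"
    )

    # Preferred after this change: pass model parameters via model_kwargs.
    llm = Replicate(model=model, model_kwargs={"image_dimensions": "512x512"})

    # Still accepted for backwards compatibility: the old `input` keyword.
    # It is mapped onto model_kwargs via the field alias and logs a
    # deprecation warning.
    legacy_llm = Replicate(model=model, input={"image_dimensions": "512x512"})
    assert legacy_llm.model_kwargs == {"image_dimensions": "512x512"}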