mirror of
https://github.com/hwchase17/langchain
synced 2024-11-04 06:00:26 +00:00
Replicate params fix (#10603)
This commit is contained in:
parent
50bb704da5
commit
ecbb1ed8cb
@ -79,7 +79,7 @@ class Replicate(LLM):
|
||||
logger.warning(
|
||||
"Init param `input` is deprecated, please use `model_kwargs` instead."
|
||||
)
|
||||
extra = {**values.get("model_kwargs", {}), **input}
|
||||
extra = {**values.pop("model_kwargs", {}), **input}
|
||||
for field_name in list(values):
|
||||
if field_name not in all_required_field_names:
|
||||
if field_name in extra:
|
||||
@ -96,7 +96,7 @@ class Replicate(LLM):
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
replicate_api_token = get_from_dict_or_env(
|
||||
values, "REPLICATE_API_TOKEN", "REPLICATE_API_TOKEN"
|
||||
values, "replicate_api_token", "REPLICATE_API_TOKEN"
|
||||
)
|
||||
values["replicate_api_token"] = replicate_api_token
|
||||
return values
|
||||
|
@ -37,6 +37,7 @@ def test_replicate_model_kwargs() -> None:
|
||||
)
|
||||
short_output = llm("What is LangChain")
|
||||
assert len(short_output) < len(long_output)
|
||||
assert llm.model_kwargs == {"max_length": 10, "temperature": 0.01}
|
||||
|
||||
|
||||
def test_replicate_input() -> None:
|
||||
|
Loading…
Reference in New Issue
Block a user