Replicate params fix (#10603)

pull/10605/head
Bagatur 1 year ago committed by GitHub
parent 50bb704da5
commit ecbb1ed8cb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -79,7 +79,7 @@ class Replicate(LLM):
logger.warning(
"Init param `input` is deprecated, please use `model_kwargs` instead."
)
extra = {**values.get("model_kwargs", {}), **input}
extra = {**values.pop("model_kwargs", {}), **input}
for field_name in list(values):
if field_name not in all_required_field_names:
if field_name in extra:
@ -96,7 +96,7 @@ class Replicate(LLM):
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
replicate_api_token = get_from_dict_or_env(
values, "REPLICATE_API_TOKEN", "REPLICATE_API_TOKEN"
values, "replicate_api_token", "REPLICATE_API_TOKEN"
)
values["replicate_api_token"] = replicate_api_token
return values

@ -37,6 +37,7 @@ def test_replicate_model_kwargs() -> None:
)
short_output = llm("What is LangChain")
assert len(short_output) < len(long_output)
assert llm.model_kwargs == {"max_length": 10, "temperature": 0.01}
def test_replicate_input() -> None:

Loading…
Cancel
Save