diff --git a/libs/langchain/langchain/llms/replicate.py b/libs/langchain/langchain/llms/replicate.py index 5d407c40b4..7a146070ec 100644 --- a/libs/langchain/langchain/llms/replicate.py +++ b/libs/langchain/langchain/llms/replicate.py @@ -79,7 +79,7 @@ class Replicate(LLM): logger.warning( "Init param `input` is deprecated, please use `model_kwargs` instead." ) - extra = {**values.get("model_kwargs", {}), **input} + extra = {**values.pop("model_kwargs", {}), **input} for field_name in list(values): if field_name not in all_required_field_names: if field_name in extra: @@ -96,7 +96,7 @@ class Replicate(LLM): def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" replicate_api_token = get_from_dict_or_env( - values, "REPLICATE_API_TOKEN", "REPLICATE_API_TOKEN" + values, "replicate_api_token", "REPLICATE_API_TOKEN" ) values["replicate_api_token"] = replicate_api_token return values diff --git a/libs/langchain/tests/integration_tests/llms/test_replicate.py b/libs/langchain/tests/integration_tests/llms/test_replicate.py index 9bc183bb8b..eaa09fc597 100644 --- a/libs/langchain/tests/integration_tests/llms/test_replicate.py +++ b/libs/langchain/tests/integration_tests/llms/test_replicate.py @@ -37,6 +37,7 @@ def test_replicate_model_kwargs() -> None: ) short_output = llm("What is LangChain") assert len(short_output) < len(long_output) + assert llm.model_kwargs == {"max_length": 10, "temperature": 0.01} def test_replicate_input() -> None: