diff --git a/libs/langchain/langchain/llms/baseten.py b/libs/langchain/langchain/llms/baseten.py index d7d768d5ec..d07fd63366 100644 --- a/libs/langchain/langchain/llms/baseten.py +++ b/libs/langchain/langchain/llms/baseten.py @@ -67,8 +67,8 @@ class Baseten(LLM): # get the model and version try: model = baseten.deployed_model_version_id(self.model) - response = model.predict({"prompt": prompt}) + response = model.predict({"prompt": prompt, **kwargs}) except baseten.common.core.ApiError: model = baseten.deployed_model_id(self.model) - response = model.predict({"prompt": prompt}) + response = model.predict({"prompt": prompt, **kwargs}) return "".join(response)