From 5efaedf488d4aacaa887d4aa46ca4aa32562fd1d Mon Sep 17 00:00:00 2001
From: Harutaka Kawamura
Date: Thu, 7 Dec 2023 01:23:17 +0900
Subject: [PATCH] Exclude `max_tokens` from request if it's None (#14334)

We found that a request with `max_tokens=None` results in the following
error from Anthropic:

```
HTTPError: 400 Client Error: Bad Request for url: https://oregon.staging.cloud.databricks.com/serving-endpoints/corey-anthropic/invocations. Response text: {"error_code":"INVALID_PARAMETER_VALUE","message":"INVALID_PARAMETER_VALUE: max_tokens was not of type Integer: null"}
```

This PR excludes `max_tokens` from the request if it's None.
---
 libs/langchain/langchain/chat_models/mlflow.py |  4 ++--
 libs/langchain/langchain/llms/databricks.py    | 15 +++++++--------
 libs/langchain/langchain/llms/mlflow.py        |  4 +++-
 3 files changed, 12 insertions(+), 11 deletions(-)

diff --git a/libs/langchain/langchain/chat_models/mlflow.py b/libs/langchain/langchain/chat_models/mlflow.py
index e1c1ad1542..4aa42f7a96 100644
--- a/libs/langchain/langchain/chat_models/mlflow.py
+++ b/libs/langchain/langchain/chat_models/mlflow.py
@@ -115,13 +115,13 @@ class ChatMlflow(BaseChatModel):
             "messages": message_dicts,
             "temperature": self.temperature,
             "n": self.n,
-            "stop": stop or self.stop,
-            "max_tokens": self.max_tokens,
             **self.extra_params,
             **kwargs,
         }
         if stop := self.stop or stop:
             data["stop"] = stop
+        if self.max_tokens is not None:
+            data["max_tokens"] = self.max_tokens
         resp = self._client.predict(endpoint=self.endpoint, inputs=data)
         return ChatMlflow._create_chat_result(resp)
 
diff --git a/libs/langchain/langchain/llms/databricks.py b/libs/langchain/langchain/llms/databricks.py
index d83e67a6cc..a3f505b5c2 100644
--- a/libs/langchain/langchain/llms/databricks.py
+++ b/libs/langchain/langchain/llms/databricks.py
@@ -334,13 +334,14 @@ class Databricks(LLM):
 
     @property
     def _llm_params(self) -> Dict[str, Any]:
-        params = {
+        params: Dict[str, Any] = {
             "temperature": self.temperature,
             "n": self.n,
-            "stop": self.stop,
-            "max_tokens": self.max_tokens,
-            **(self.model_kwargs or self.extra_params),
         }
+        if self.stop:
+            params["stop"] = self.stop
+        if self.max_tokens is not None:
+            params["max_tokens"] = self.max_tokens
         return params
 
     @validator("cluster_id", always=True)
@@ -457,11 +458,9 @@
         request: Dict[str, Any] = {"prompt": prompt}
         if self._client.llm:
             request.update(self._llm_params)
-            request.update(self.model_kwargs or self.extra_params)
-        else:
-            request.update(self.model_kwargs or self.extra_params)
+        request.update(self.model_kwargs or self.extra_params)
         request.update(kwargs)
-        if stop := self.stop or stop:
+        if stop:
             request["stop"] = stop
 
         if self.transform_input_fn:
diff --git a/libs/langchain/langchain/llms/mlflow.py b/libs/langchain/langchain/llms/mlflow.py
index 565a4b3a36..00e16bcb46 100644
--- a/libs/langchain/langchain/llms/mlflow.py
+++ b/libs/langchain/langchain/llms/mlflow.py
@@ -106,12 +106,14 @@ class Mlflow(LLM):
             "prompt": prompt,
             "temperature": self.temperature,
             "n": self.n,
-            "max_tokens": self.max_tokens,
             **self.extra_params,
             **kwargs,
         }
         if stop := self.stop or stop:
             data["stop"] = stop
+        if self.max_tokens is not None:
+            data["max_tokens"] = self.max_tokens
+
         resp = self._client.predict(endpoint=self.endpoint, inputs=data)
         return resp["choices"][0]["text"]
 
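
Note (not part of the patch): a minimal standalone sketch of the pattern the patch applies, for anyone hitting the same 400 against a different endpoint. `build_request` and its parameters are illustrative names, not APIs from this repo:

```python
import json
from typing import Any, Dict, List, Optional


def build_request(
    prompt: str,
    temperature: float = 0.0,
    stop: Optional[List[str]] = None,
    max_tokens: Optional[int] = None,
) -> Dict[str, Any]:
    """Build a serving-endpoint payload, omitting unset optional keys."""
    data: Dict[str, Any] = {"prompt": prompt, "temperature": temperature}
    # Include optional keys only when they carry a value; serializing
    # `"max_tokens": null` is what triggered the INVALID_PARAMETER_VALUE
    # error quoted in the commit message.
    if stop:
        data["stop"] = stop
    if max_tokens is not None:
        data["max_tokens"] = max_tokens
    return data


# With max_tokens left as None, the key never reaches the JSON payload.
assert "max_tokens" not in build_request("Hello")
print(json.dumps(build_request("Hello", stop=["\n"], max_tokens=256)))
```

The asymmetry mirrors the patch itself: `stop` is dropped whenever it is falsy (None or an empty list is equally meaningless as a stop sequence), while `max_tokens` is checked with `is not None` so only a true None is excluded.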