From 21fbbe83a7adda89863757a8fdb8a9f762dc5f02 Mon Sep 17 00:00:00 2001
From: Taqi Jaffri
Date: Tue, 12 Sep 2023 15:40:55 -0700
Subject: [PATCH 1/2] Fix fine-tuned replicate models with faster cold boot
 (#10512)

With the latest support for faster cold boots in Replicate
(https://replicate.com/blog/fine-tune-cold-boots), the Replicate LLM
support in langchain is broken, because some internal Replicate inputs are
now returned ahead of the prompt. The screenshot below illustrates the
problem:

[screenshot omitted]

As the screenshot shows, the new replicate_weights param is sent down with
x-order = 0, which causes langchain to use that param as the prompt input
instead of prompt, which has x-order = 1.

FYI @baskaryan, this requires a fix; otherwise Replicate is broken for
these models. I have pinged Replicate to ask whether they want to fix it
on their end by changing the x-order they return.

Update: per suggestion, I updated the PR to simply allow callers to set
prompt_key manually, which can be set to "prompt" in this case. If you
know the prompt key for your model, this is also faster than dynamically
querying the model's schema on every call.

---------

Co-authored-by: Taqi Jaffri
---
 libs/langchain/langchain/llms/replicate.py | 22 +++++++++++++---------
 1 file changed, 13 insertions(+), 9 deletions(-)

diff --git a/libs/langchain/langchain/llms/replicate.py b/libs/langchain/langchain/llms/replicate.py
index 9fa2807b4e..c5b8a8da6a 100644
--- a/libs/langchain/langchain/llms/replicate.py
+++ b/libs/langchain/langchain/llms/replicate.py
@@ -33,6 +33,7 @@ class Replicate(LLM):
     input: Dict[str, Any] = Field(default_factory=dict)
     model_kwargs: Dict[str, Any] = Field(default_factory=dict)
     replicate_api_token: Optional[str] = None
+    prompt_key: Optional[str] = None
     streaming: bool = Field(default=False)
     """Whether to stream the results."""

@@ -114,15 +115,18 @@ class Replicate(LLM):
         model = replicate_python.models.get(model_str)
         version = model.versions.get(version_str)

-        # sort through the openapi schema to get the name of the first input
-        input_properties = sorted(
-            version.openapi_schema["components"]["schemas"]["Input"][
-                "properties"
-            ].items(),
-            key=lambda item: item[1].get("x-order", 0),
-        )
-        first_input_name = input_properties[0][0]
-        inputs = {first_input_name: prompt, **self.input}
+        if not self.prompt_key:
+            # sort through the openapi schema to get the name of the first input
+            input_properties = sorted(
+                version.openapi_schema["components"]["schemas"]["Input"][
+                    "properties"
+                ].items(),
+                key=lambda item: item[1].get("x-order", 0),
+            )
+
+            self.prompt_key = input_properties[0][0]
+
+        inputs = {self.prompt_key: prompt, **self.input}

         prediction = replicate_python.predictions.create(
             version=version, input={**inputs, **kwargs}

From 7ecee7821a9c71b23c71772179799e0657458078 Mon Sep 17 00:00:00 2001
From: Bagatur
Date: Tue, 12 Sep 2023 15:46:36 -0700
Subject: [PATCH 2/2] Replicate fix linting

---
 libs/langchain/langchain/llms/replicate.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/libs/langchain/langchain/llms/replicate.py b/libs/langchain/langchain/llms/replicate.py
index c5b8a8da6a..4ce4621d16 100644
--- a/libs/langchain/langchain/llms/replicate.py
+++ b/libs/langchain/langchain/llms/replicate.py
@@ -1,5 +1,5 @@
 import logging
-from typing import Any, Dict, List, Mapping, Optional
+from typing import Any, Dict, List, Optional

 from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.llms.base import LLM
@@ -82,7 +82,7 @@ class Replicate(LLM):
         return values

     @property
-    def _identifying_params(self) -> Mapping[str, Any]:
+    def _identifying_params(self) -> Dict[str, Any]:
         """Get the identifying parameters."""
         return {
             "model": self.model,
@@ -126,7 +126,7 @@ class Replicate(LLM):

             self.prompt_key = input_properties[0][0]

-        inputs = {self.prompt_key: prompt, **self.input}
+        inputs: Dict = {self.prompt_key: prompt, **self.input}

         prediction = replicate_python.predictions.create(
             version=version, input={**inputs, **kwargs}
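
Note for reviewers: a minimal usage sketch of the new prompt_key
parameter. The model identifier and the sampling kwargs below are
illustrative placeholders, not part of this patch; REPLICATE_API_TOKEN is
assumed to be set in the environment.

    from langchain.llms import Replicate

    # Placeholder "owner/name:version" identifier for a fine-tuned model.
    llm = Replicate(
        model="your-username/your-fine-tune:0123456789abcdef",
        # Setting prompt_key explicitly skips the x-order schema sort, so
        # replicate_weights (x-order = 0) can no longer shadow the prompt.
        prompt_key="prompt",
        model_kwargs={"temperature": 0.75, "max_length": 500},
    )
    print(llm("Tell me about faster cold boots."))

Passing prompt_key also avoids one schema-fetching round trip per call,
which is the performance point made in the PR description above.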