mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
Fix fine-tuned replicate models with faster cold boot (#10512)
With the latest support for faster cold boot in replicate https://replicate.com/blog/fine-tune-cold-boots it looks like the replicate LLM support in langchain is broken since some internal replicate inputs are being returned. Screenshot below illustrates the problem: <img width="1917" alt="image" src="https://github.com/langchain-ai/langchain/assets/749277/d28c27cc-40fb-4258-8710-844c00d3c2b0"> As you can see, the new replicate_weights param is being sent down with x-order = 0 (which is causing langchain to use that param instead of prompt which is x-order = 1) FYI @baskaryan this requires a fix otherwise replicate is broken for these models. I have pinged replicate whether they want to fix it on their end by changing the x-order returned by them. Update: per suggestion I updated the PR to just allow manually setting the prompt_key which can be set to "prompt" in this case by callers... I think this is going to be faster anyway than trying to dynamically query the model every time if you know the prompt key for your model. --------- Co-authored-by: Taqi Jaffri <tjaffri@docugami.com>
This commit is contained in:
parent
57e2de2077
commit
21fbbe83a7
@ -33,6 +33,7 @@ class Replicate(LLM):
|
|||||||
input: Dict[str, Any] = Field(default_factory=dict)
|
input: Dict[str, Any] = Field(default_factory=dict)
|
||||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||||
replicate_api_token: Optional[str] = None
|
replicate_api_token: Optional[str] = None
|
||||||
|
prompt_key: Optional[str] = None
|
||||||
|
|
||||||
streaming: bool = Field(default=False)
|
streaming: bool = Field(default=False)
|
||||||
"""Whether to stream the results."""
|
"""Whether to stream the results."""
|
||||||
@ -114,15 +115,18 @@ class Replicate(LLM):
|
|||||||
model = replicate_python.models.get(model_str)
|
model = replicate_python.models.get(model_str)
|
||||||
version = model.versions.get(version_str)
|
version = model.versions.get(version_str)
|
||||||
|
|
||||||
# sort through the openapi schema to get the name of the first input
|
if not self.prompt_key:
|
||||||
input_properties = sorted(
|
# sort through the openapi schema to get the name of the first input
|
||||||
version.openapi_schema["components"]["schemas"]["Input"][
|
input_properties = sorted(
|
||||||
"properties"
|
version.openapi_schema["components"]["schemas"]["Input"][
|
||||||
].items(),
|
"properties"
|
||||||
key=lambda item: item[1].get("x-order", 0),
|
].items(),
|
||||||
)
|
key=lambda item: item[1].get("x-order", 0),
|
||||||
first_input_name = input_properties[0][0]
|
)
|
||||||
inputs = {first_input_name: prompt, **self.input}
|
|
||||||
|
self.prompt_key = input_properties[0][0]
|
||||||
|
|
||||||
|
inputs = {self.prompt_key: prompt, **self.input}
|
||||||
|
|
||||||
prediction = replicate_python.predictions.create(
|
prediction = replicate_python.predictions.create(
|
||||||
version=version, input={**inputs, **kwargs}
|
version=version, input={**inputs, **kwargs}
|
||||||
|
Loading…
Reference in New Issue
Block a user