Allow replicate prompt key to be manually specified (#10516)

Since the prompt-key inference logic doesn't work for all models.

Co-authored-by: Taqi Jaffri <tjaffri@gmail.com>
Co-authored-by: Taqi Jaffri <tjaffri@docugami.com>
Bagatur committed eaf916f999 via GitHub
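For context, a minimal usage sketch of the new parameter (the model identifier below is a hypothetical placeholder, not taken from this diff):

    from langchain.llms import Replicate

    # Name the prompt input explicitly instead of relying on schema inference.
    # "owner/model:version" is a made-up placeholder identifier.
    llm = Replicate(
        model="owner/model:version",
        prompt_key="prompt",
    )
    print(llm("Tell me a joke."))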

@@ -1,5 +1,5 @@
 import logging
-from typing import Any, Dict, List, Mapping, Optional
+from typing import Any, Dict, List, Optional
 
 from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.llms.base import LLM
@@ -33,6 +33,7 @@ class Replicate(LLM):
     input: Dict[str, Any] = Field(default_factory=dict)
     model_kwargs: Dict[str, Any] = Field(default_factory=dict)
     replicate_api_token: Optional[str] = None
+    prompt_key: Optional[str] = None
     streaming: bool = Field(default=False)
     """Whether to stream the results."""
@@ -81,7 +82,7 @@ class Replicate(LLM):
         return values
 
     @property
-    def _identifying_params(self) -> Mapping[str, Any]:
+    def _identifying_params(self) -> Dict[str, Any]:
         """Get the identifying parameters."""
         return {
             "model": self.model,
@@ -114,15 +115,18 @@ class Replicate(LLM):
         model = replicate_python.models.get(model_str)
         version = model.versions.get(version_str)
-        # sort through the openapi schema to get the name of the first input
-        input_properties = sorted(
-            version.openapi_schema["components"]["schemas"]["Input"][
-                "properties"
-            ].items(),
-            key=lambda item: item[1].get("x-order", 0),
-        )
-        first_input_name = input_properties[0][0]
-        inputs = {first_input_name: prompt, **self.input}
+        if not self.prompt_key:
+            # sort through the openapi schema to get the name of the first input
+            input_properties = sorted(
+                version.openapi_schema["components"]["schemas"]["Input"][
+                    "properties"
+                ].items(),
+                key=lambda item: item[1].get("x-order", 0),
+            )
+            self.prompt_key = input_properties[0][0]
+        inputs: Dict = {self.prompt_key: prompt, **self.input}
         prediction = replicate_python.predictions.create(
             version=version, input={**inputs, **kwargs}
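When prompt_key is unset, the fallback still infers it by sorting the model's OpenAPI input properties on x-order and taking the first. A toy illustration of that same sort, using a made-up schema fragment:

    # Hypothetical "Input" schema properties, shaped like Replicate's OpenAPI output.
    properties = {
        "max_length": {"type": "integer", "x-order": 2},
        "prompt": {"type": "string", "x-order": 0},
        "temperature": {"type": "number", "x-order": 1},
    }

    # Same sort as in the diff above: lowest x-order wins.
    input_properties = sorted(
        properties.items(), key=lambda item: item[1].get("x-order", 0)
    )
    print(input_properties[0][0])  # -> "prompt"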
