cache replicate version (#10517)

In a subsequent PR, _call will be updated to use replicate.run directly when
not streaming, so the version object won't be needed at all.

cc @cbh123 @tjaffri
This commit is contained in:
Bagatur 2023-09-14 08:34:04 -07:00 committed by GitHub
parent 49b65a1b57
commit ccf71e23e8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import logging
from typing import Any, Dict, List, Optional
@ -23,10 +25,14 @@ class Replicate(LLM):
.. code-block:: python
from langchain.llms import Replicate
replicate = Replicate(model="stability-ai/stable-diffusion: \
27b93a2413e7f36cd83da926f365628\
0b2931564ff050bf9575f1fdf9bcd7478",
input={"image_dimensions": "512x512"})
replicate = Replicate(
model=(
"stability-ai/stable-diffusion:"
"27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478"
),
input={"image_dimensions": "512x512"}
)
"""
model: str
@ -34,6 +40,11 @@ class Replicate(LLM):
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
replicate_api_token: Optional[str] = None
prompt_key: Optional[str] = None
version_obj: Any = Field(default=None, exclude=True)
"""Optionally pass in the model version object during initialization to avoid
having to make an extra API call to retrieve it during streaming. NOTE: this
object is not serializable and is excluded from serialization.
"""
streaming: bool = Field(default=False)
"""Whether to stream the results."""
@ -111,14 +122,15 @@ class Replicate(LLM):
)
# get the model and version
if self.version_obj is None:
model_str, version_str = self.model.split(":")
model = replicate_python.models.get(model_str)
version = model.versions.get(version_str)
self.version_obj = model.versions.get(version_str)
if not self.prompt_key:
if self.prompt_key is None:
# sort through the openapi schema to get the name of the first input
input_properties = sorted(
version.openapi_schema["components"]["schemas"]["Input"][
self.version_obj.openapi_schema["components"]["schemas"]["Input"][
"properties"
].items(),
key=lambda item: item[1].get("x-order", 0),
@ -129,7 +141,7 @@ class Replicate(LLM):
inputs: Dict = {self.prompt_key: prompt, **self.input}
prediction = replicate_python.predictions.create(
version=version, input={**inputs, **kwargs}
version=self.version_obj, input={**inputs, **kwargs}
)
current_completion: str = ""
stop_condition_reached = False