mirror of
https://github.com/hwchase17/langchain
synced 2024-11-04 06:00:26 +00:00
cache replicate version (#10517)
A subsequent PR will update `_call` to use `replicate.run` directly when not streaming, so the version object won't be needed at all. cc @cbh123 @tjaffri
This commit is contained in:
parent
49b65a1b57
commit
ccf71e23e8
@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
@ -23,10 +25,14 @@ class Replicate(LLM):
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.llms import Replicate
|
||||
replicate = Replicate(model="stability-ai/stable-diffusion: \
|
||||
27b93a2413e7f36cd83da926f365628\
|
||||
0b2931564ff050bf9575f1fdf9bcd7478",
|
||||
input={"image_dimensions": "512x512"})
|
||||
|
||||
replicate = Replicate(
|
||||
model=(
|
||||
"stability-ai/stable-diffusion:"
|
||||
"27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478"
|
||||
),
|
||||
input={"image_dimensions": "512x512"}
|
||||
)
|
||||
"""
|
||||
|
||||
model: str
|
||||
@ -34,6 +40,11 @@ class Replicate(LLM):
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
replicate_api_token: Optional[str] = None
|
||||
prompt_key: Optional[str] = None
|
||||
version_obj: Any = Field(default=None, exclude=True)
|
||||
"""Optionally pass in the model version object during initialization to avoid
|
||||
having to make an extra API call to retrieve it during streaming. NOTE: not
|
||||
serializable, so it is excluded from serialization.
|
||||
"""
|
||||
|
||||
streaming: bool = Field(default=False)
|
||||
"""Whether to stream the results."""
|
||||
@ -111,14 +122,15 @@ class Replicate(LLM):
|
||||
)
|
||||
|
||||
# get the model and version
|
||||
if self.version_obj is None:
|
||||
model_str, version_str = self.model.split(":")
|
||||
model = replicate_python.models.get(model_str)
|
||||
version = model.versions.get(version_str)
|
||||
self.version_obj = model.versions.get(version_str)
|
||||
|
||||
if not self.prompt_key:
|
||||
if self.prompt_key is None:
|
||||
# sort through the openapi schema to get the name of the first input
|
||||
input_properties = sorted(
|
||||
version.openapi_schema["components"]["schemas"]["Input"][
|
||||
self.version_obj.openapi_schema["components"]["schemas"]["Input"][
|
||||
"properties"
|
||||
].items(),
|
||||
key=lambda item: item[1].get("x-order", 0),
|
||||
@ -129,7 +141,7 @@ class Replicate(LLM):
|
||||
inputs: Dict = {self.prompt_key: prompt, **self.input}
|
||||
|
||||
prediction = replicate_python.predictions.create(
|
||||
version=version, input={**inputs, **kwargs}
|
||||
version=self.version_obj, input={**inputs, **kwargs}
|
||||
)
|
||||
current_completion: str = ""
|
||||
stop_condition_reached = False
|
||||
|
Loading…
Reference in New Issue
Block a user