cache replicate version (#10517)
In a subsequent PR, _call will be updated to use replicate.run directly when not streaming, so the version object isn't needed at all. cc @cbh123 @tjaffri
parent: 49b65a1b57
commit: ccf71e23e8
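The follow-up described in the message, calling replicate.run directly when not streaming, would make both the extra API call and the cached version object unnecessary, since the client resolves the "owner/model:version" string itself. A minimal sketch of that idea against the public replicate-python API; the model string is the one from the docstring example below, and this code is the stated plan, not part of this commit:

```python
import replicate

# replicate.run parses the "owner/model:version" string on its own, so no
# models.get()/versions.get() round trip (and no version object) is needed.
output = replicate.run(
    "stability-ai/stable-diffusion:"
    "27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478",
    input={"prompt": "a placeholder prompt"},  # illustrative input only
)
```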
```diff
@@ -1,3 +1,5 @@
 from __future__ import annotations
 
+import logging
+
 from typing import Any, Dict, List, Optional
```
```diff
@@ -23,10 +25,14 @@ class Replicate(LLM):
         .. code-block:: python
 
             from langchain.llms import Replicate
-            replicate = Replicate(model="stability-ai/stable-diffusion: \
-                                        27b93a2413e7f36cd83da926f365628\
-                                        0b2931564ff050bf9575f1fdf9bcd7478",
-                                  input={"image_dimensions": "512x512"})
+
+            replicate = Replicate(
+                model=(
+                    "stability-ai/stable-diffusion: "
+                    "27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478"
+                ),
+                input={"image_dimensions": "512x512"}
+            )
     """
 
     model: str
```
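The rewritten docstring example swaps backslash line continuations for Python's implicit concatenation of adjacent string literals. This is more than cosmetic: with a backslash inside a string literal, the indentation of the continuation line becomes part of the string, so the old example produced a model string with embedded runs of spaces. A small demonstration:

```python
# Adjacent string literals are joined at compile time, with no stray whitespace.
model = (
    "stability-ai/stable-diffusion: "
    "27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478"
)
assert "  " not in model

# A backslash continuation inside a literal keeps the next line's indentation.
old = "stability-ai/stable-diffusion: \
        27b93a2413e7f36cd83da926f365628"
assert "        " in old  # the continuation indent leaked into the string
```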
```diff
@@ -34,6 +40,11 @@ class Replicate(LLM):
     model_kwargs: Dict[str, Any] = Field(default_factory=dict)
     replicate_api_token: Optional[str] = None
     prompt_key: Optional[str] = None
+    version_obj: Any = Field(default=None, exclude=True)
+    """Optionally pass in the model version object during initialization to avoid
+    having to make an extra API call to retrieve it during streaming. NOTE: not
+    serializable, is excluded from serialization.
+    """
 
     streaming: bool = Field(default=False)
     """Whether to stream the results."""
```
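With the new version_obj field, a caller can pay the lookup cost once and share the fetched version across instances, which is what the field's docstring is for. A hypothetical usage sketch (models.get and versions.get are the same replicate-python calls that the code below now wraps in a None check):

```python
import replicate as replicate_python
from langchain.llms import Replicate

MODEL = "stability-ai/stable-diffusion"
VERSION = "27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478"

# One API round trip up front...
version = replicate_python.models.get(MODEL).versions.get(VERSION)

# ...then reuse the cached object; _call will skip its own lookup.
llm = Replicate(model=f"{MODEL}:{VERSION}", version_obj=version)
```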
```diff
@@ -111,14 +122,15 @@ class Replicate(LLM):
         )
 
         # get the model and version
-        model_str, version_str = self.model.split(":")
-        model = replicate_python.models.get(model_str)
-        version = model.versions.get(version_str)
+        if self.version_obj is None:
+            model_str, version_str = self.model.split(":")
+            model = replicate_python.models.get(model_str)
+            self.version_obj = model.versions.get(version_str)
 
-        if not self.prompt_key:
+        if self.prompt_key is None:
             # sort through the openapi schema to get the name of the first input
             input_properties = sorted(
-                version.openapi_schema["components"]["schemas"]["Input"][
+                self.version_obj.openapi_schema["components"]["schemas"]["Input"][
                     "properties"
                 ].items(),
                 key=lambda item: item[1].get("x-order", 0),
```
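Two details in this hunk are easy to miss: the guard changed from `if not self.prompt_key` to `if self.prompt_key is None`, so a falsy-but-set key no longer triggers re-derivation; and the prompt key is still found by sorting the version's OpenAPI input properties on their x-order extension. A toy run of that sort over a made-up schema fragment:

```python
# Invented stand-in for openapi_schema["components"]["schemas"]["Input"]["properties"]
properties = {
    "width": {"type": "integer", "x-order": 2},
    "prompt": {"type": "string", "x-order": 0},
    "negative_prompt": {"type": "string", "x-order": 1},
}

# Same idiom as the diff: order by x-order (default 0) and take the first name.
prompt_key = sorted(properties.items(), key=lambda item: item[1].get("x-order", 0))[0][0]
assert prompt_key == "prompt"
```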
```diff
@@ -129,7 +141,7 @@ class Replicate(LLM):
         inputs: Dict = {self.prompt_key: prompt, **self.input}
 
         prediction = replicate_python.predictions.create(
-            version=version, input={**inputs, **kwargs}
+            version=self.version_obj, input={**inputs, **kwargs}
         )
         current_completion: str = ""
         stop_condition_reached = False
```
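For context on the two variables at the end of this hunk: the streaming path iterates the created prediction and accumulates chunks until a stop sequence appears. A hedged sketch of that loop (output_iterator is the replicate-python streaming iterator; the exact bookkeeping in _call may differ):

```python
def collect_until_stop(prediction, stop_sequences=None):
    """Accumulate streamed chunks, halting once any stop sequence shows up."""
    current_completion = ""
    stop_condition_reached = False
    for output in prediction.output_iterator():
        current_completion += output
        if any(s in current_completion for s in stop_sequences or []):
            stop_condition_reached = True
            break
    return current_completion, stop_condition_reached
```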