from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.outputs import GenerationChunk
from langchain_core.pydantic_v1 import Extra, Field, root_validator
from langchain_core.utils import get_from_dict_or_env

if TYPE_CHECKING:
    from replicate.prediction import Prediction

logger = logging.getLogger(__name__)


class Replicate(LLM):
    """Replicate models.

    To use, you should have the ``replicate`` python package installed,
    and the environment variable ``REPLICATE_API_TOKEN`` set with your API token.
    You can find your token here: https://replicate.com/account

    The model param is required, but any other model parameters can also
    be passed in with the format model_kwargs={model_param: value, ...}

    Example:
        .. code-block:: python

            from langchain_community.llms import Replicate

            replicate = Replicate(
                model=(
                    "stability-ai/stable-diffusion:"
                    "27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478"
                ),
                model_kwargs={"image_dimensions": "512x512"},
            )
    """

    model: str
    model_kwargs: Dict[str, Any] = Field(default_factory=dict, alias="input")
    replicate_api_token: Optional[str] = None
    prompt_key: Optional[str] = None
    version_obj: Any = Field(default=None, exclude=True)
    """Optionally pass in the model version object during initialization to avoid
    having to make an extra API call to retrieve it during streaming.
    NOTE: not serializable; excluded from serialization.
    """

    streaming: bool = False
    """Whether to stream the results."""

    stop: List[str] = Field(default_factory=list)
    """Stop sequences to early-terminate generation."""

    class Config:
        """Configuration for this pydantic model."""

        allow_population_by_field_name = True
        extra = Extra.forbid

    @property
    def lc_secrets(self) -> Dict[str, str]:
        return {"replicate_api_token": "REPLICATE_API_TOKEN"}

    @classmethod
    def is_lc_serializable(cls) -> bool:
        return True

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        """Get the namespace of the langchain object."""
        return ["langchain", "llms", "replicate"]

    @root_validator(pre=True)
    def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Build extra kwargs from additional params that were passed in."""
        all_required_field_names = {field.alias for field in cls.__fields__.values()}

        # Support the deprecated `input` alias by merging it into model_kwargs.
        input = values.pop("input", {})
        if input:
            logger.warning(
                "Init param `input` is deprecated, please use `model_kwargs` instead."
            )
        extra = {**values.pop("model_kwargs", {}), **input}

        # Move any unrecognized top-level params into model_kwargs.
        for field_name in list(values):
            if field_name not in all_required_field_names:
                if field_name in extra:
                    raise ValueError(f"Found {field_name} supplied twice.")
                logger.warning(
                    f"""{field_name} was transferred to model_kwargs.
                    Please confirm that {field_name} is what you intended."""
                )
                extra[field_name] = values.pop(field_name)
        values["model_kwargs"] = extra
        return values

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the API token exists in the environment or was passed in."""
        replicate_api_token = get_from_dict_or_env(
            values, "replicate_api_token", "REPLICATE_API_TOKEN"
        )
        values["replicate_api_token"] = replicate_api_token
        return values

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return {
            "model": self.model,
            "model_kwargs": self.model_kwargs,
        }

    @property
    def _llm_type(self) -> str:
        """Return type of model."""
        return "replicate"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call to replicate endpoint."""
        if self.streaming:
            # Accumulate streamed chunks into a single completion string.
            completion: Optional[str] = None
            for chunk in self._stream(
                prompt, stop=stop, run_manager=run_manager, **kwargs
            ):
                if completion is None:
                    completion = chunk.text
                else:
                    completion += chunk.text
        else:
            # Block until the prediction finishes, then collect its output.
            prediction = self._create_prediction(prompt, **kwargs)
            prediction.wait()
            if prediction.status == "failed":
                raise RuntimeError(prediction.error)
            if isinstance(prediction.output, str):
                completion = prediction.output
            else:
                completion = "".join(prediction.output)
        assert completion is not None

        # Truncate the completion at the first stop sequence, if any.
        stop_conditions = stop or self.stop
        for s in stop_conditions:
            if s in completion:
                completion = completion[: completion.find(s)]
        return completion

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        prediction = self._create_prediction(prompt, **kwargs)
        stop_conditions = stop or self.stop
        stop_condition_reached = False
        current_completion: str = ""
        for output in prediction.output_iterator():
            current_completion += output

            # Test for stop conditions, if specified.
            for s in stop_conditions:
                if s in current_completion:
                    prediction.cancel()
                    stop_condition_reached = True
                    # Potentially some tokens should still be yielded before
                    # ending the stream.
                    stop_index = max(output.find(s), 0)
                    output = output[:stop_index]
                    if not output:
                        break

            if output:
                if run_manager:
                    run_manager.on_llm_new_token(
                        output,
                        verbose=self.verbose,
                    )
                yield GenerationChunk(text=output)

            if stop_condition_reached:
                break

    def _create_prediction(self, prompt: str, **kwargs: Any) -> Prediction:
        try:
            import replicate as replicate_python
        except ImportError:
            raise ImportError(
                "Could not import replicate python package. "
                "Please install it with `pip install replicate`."
            )

        # Resolve the model version: "owner/name:version" pins a specific
        # version; a bare "owner/name" uses the latest version.
        if self.version_obj is None:
            if ":" in self.model:
                model_str, version_str = self.model.split(":")
                model = replicate_python.models.get(model_str)
                self.version_obj = model.versions.get(version_str)
            else:
                model = replicate_python.models.get(self.model)
                self.version_obj = model.latest_version

        if self.prompt_key is None:
            # Sort through the openapi schema to get the name of the first input.
            input_properties = sorted(
                self.version_obj.openapi_schema["components"]["schemas"]["Input"][
                    "properties"
                ].items(),
                key=lambda item: item[1].get("x-order", 0),
            )

            self.prompt_key = input_properties[0][0]

        input_: Dict = {
            self.prompt_key: prompt,
            **self.model_kwargs,
            **kwargs,
        }

        # Official models (no version hash in the identifier) go through the
        # models.predictions API; versioned models through predictions.create.
        if ":" not in self.model:
            return replicate_python.models.predictions.create(self.model, input=input_)
        else:
            return replicate_python.predictions.create(
                version=self.version_obj, input=input_
            )
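

# Usage sketch (illustrative only, not part of the class above). Assumes the
# ``replicate`` package is installed, ``REPLICATE_API_TOKEN`` is set in the
# environment, and MODEL_ID (a hypothetical placeholder) names a
# text-generation model in "owner/name:version" format.
if __name__ == "__main__":
    MODEL_ID = "owner/name:version"

    # Blocking call: waits for the prediction and returns the full completion.
    llm = Replicate(model=MODEL_ID)
    print(llm.invoke("Say hello in one short sentence."))

    # Streaming call: tokens are yielded as Replicate produces them, and the
    # prediction is cancelled early once a stop sequence appears.
    streaming_llm = Replicate(model=MODEL_ID, streaming=True, stop=["\n\n"])
    for token in streaming_llm.stream("Count from 1 to 5."):
        print(token, end="", flush=True)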