add replicate stream (#10518)

Support direct Replicate streaming. cc @cbh123 @tjaffri
Bagatur authored 1 year ago, committed by GitHub
parent 7f3f6097e7
commit 9dd4cacae2

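A rough usage sketch of what this change enables (the model string is a placeholder and REPLICATE_API_TOKEN is assumed to be set; neither is part of this diff). With streaming on, _call aggregates chunks from the new _stream method, and token-level streaming is exposed through the usual LLM interface:

from langchain.llms import Replicate

# Assumes REPLICATE_API_TOKEN is set in the environment and that
# "<owner>/<model>:<version>" is replaced with a real Replicate model id.
llm = Replicate(model="<owner>/<model>:<version>", streaming=True)

# _call path: chunks produced by the new _stream are accumulated into one string.
print(llm("Tell me a joke"))

# Token-by-token consumption through the standard stream() interface.
for token in llm.stream("Tell me a joke"):
    print(token, end="", flush=True)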
@@ -1,13 +1,17 @@
 from __future__ import annotations
 
 import logging
-from typing import Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional
 
 from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.llms.base import LLM
 from langchain.pydantic_v1 import Extra, Field, root_validator
+from langchain.schema.output import GenerationChunk
 from langchain.utils import get_from_dict_or_env
 
+if TYPE_CHECKING:
+    from replicate.prediction import Prediction
+
 logger = logging.getLogger(__name__)
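The new TYPE_CHECKING block keeps replicate an optional dependency: the Prediction import is only resolved by type checkers, and with `from __future__ import annotations` the return annotation on _create_prediction is never evaluated at runtime. The general pattern, as a standalone sketch (not tied to this file):

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen by mypy/pyright only; not imported when the module actually runs.
    from replicate.prediction import Prediction


def describe(prediction: Prediction) -> str:
    # PEP 563 keeps the annotation a string, so importing this module works
    # even when the replicate package is not installed.
    return f"status={prediction.status}"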
@@ -46,10 +50,10 @@ class Replicate(LLM):
         serializable, is excluded from serialization.
     """
 
-    streaming: bool = Field(default=False)
+    streaming: bool = False
     """Whether to stream the results."""
 
-    stop: Optional[List[str]] = Field(default=[])
+    stop: List[str] = Field(default_factory=list)
    """Stop sequences to early-terminate generation."""
 
     class Config:
@@ -97,7 +101,7 @@ class Replicate(LLM):
         """Get the identifying parameters."""
         return {
             "model": self.model,
-            **{"model_kwargs": self.model_kwargs},
+            "model_kwargs": self.model_kwargs,
         }
 
     @property
@@ -113,6 +117,63 @@ class Replicate(LLM):
         **kwargs: Any,
     ) -> str:
         """Call to replicate endpoint."""
+        if self.streaming:
+            completion: Optional[str] = None
+            for chunk in self._stream(
+                prompt, stop=stop, run_manager=run_manager, **kwargs
+            ):
+                if completion is None:
+                    completion = chunk.text
+                else:
+                    completion += chunk.text
+        else:
+            prediction = self._create_prediction(prompt, **kwargs)
+            prediction.wait()
+            if prediction.status == "failed":
+                raise RuntimeError(prediction.error)
+            completion = prediction.output
+        assert completion is not None
+        stop_conditions = stop or self.stop
+        for s in stop_conditions:
+            if s in completion:
+                completion = completion[: completion.find(s)]
+        return completion
+
+    def _stream(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[GenerationChunk]:
+        prediction = self._create_prediction(prompt, **kwargs)
+        stop_conditions = stop or self.stop
+        stop_condition_reached = False
+        current_completion: str = ""
+        for output in prediction.output_iterator():
+            current_completion += output
+            # test for stop conditions, if specified
+            for s in stop_conditions:
+                if s in current_completion:
+                    prediction.cancel()
+                    stop_condition_reached = True
+                    # Potentially some tokens that should still be yielded before ending
+                    # stream.
+                    stop_index = max(output.find(s), 0)
+                    output = output[:stop_index]
+                    if not output:
+                        break
+            if output:
+                yield GenerationChunk(text=output)
+                if run_manager:
+                    run_manager.on_llm_new_token(
+                        output,
+                        verbose=self.verbose,
+                    )
+            if stop_condition_reached:
+                break
+
+    def _create_prediction(self, prompt: str, **kwargs: Any) -> Prediction:
         try:
             import replicate as replicate_python
         except ImportError:
@@ -138,29 +199,7 @@ class Replicate(LLM):
             self.prompt_key = input_properties[0][0]
 
-        inputs: Dict = {self.prompt_key: prompt, **self.input}
-
-        prediction = replicate_python.predictions.create(
-            version=self.version_obj, input={**inputs, **kwargs}
+        input_: Dict = {self.prompt_key: prompt, **self.input, **kwargs}
+        return replicate_python.predictions.create(
+            version=self.version_obj, input=input_
         )
-        current_completion: str = ""
-        stop_condition_reached = False
-        for output in prediction.output_iterator():
-            current_completion += output
-
-            # test for stop conditions, if specified
-            if stop:
-                for s in stop:
-                    if s in current_completion:
-                        prediction.cancel()
-                        stop_index = current_completion.find(s)
-                        current_completion = current_completion[:stop_index]
-                        stop_condition_reached = True
-                        break
-
-            if stop_condition_reached:
-                break
-
-            if self.streaming and run_manager:
-                run_manager.on_llm_new_token(output)
-        return current_completion
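The stop handling in the new _stream can be hard to follow in diff form. A self-contained sketch of the same idea (illustrative names, with a plain iterator standing in for prediction.output_iterator() and no prediction.cancel() call):

from typing import Iterator, List


def stream_until_stop(tokens: Iterator[str], stop: List[str]) -> Iterator[str]:
    """Yield chunks until any stop sequence appears in the accumulated text."""
    accumulated = ""
    for output in tokens:
        accumulated += output
        stop_reached = False
        for s in stop:
            if s in accumulated:
                stop_reached = True
                # Keep the part of this chunk that precedes the stop sequence.
                output = output[: max(output.find(s), 0)]
        if output:
            yield output
        if stop_reached:
            break


# Yields "Hello", " wor", "ld " and then stops.
print(list(stream_until_stop(iter(["Hello", " wor", "ld STOP more"]), ["STOP"])))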
