Added stop sequence support to replicate (#8107)

Stop sequences are useful if you are doing long-running completions and
need to early-out rather than running for the full max_length... not
only does this save inference cost on Replicate, it is also much faster
if you are going to truncate the output later anyway.

Other LLMs support stop sequences natively (e.g. OpenAI) but I didn't
see this for Replicate so adding this via their prediction cancel
method.

Housekeeping: I ran `make format` and `make lint`, no issues reported in
the files I touched.

I did update the replicate integration test and ran `poetry run pytest
tests/integration_tests/llms/test_replicate.py` successfully.

Finally, I am @tjaffri https://twitter.com/tjaffri for feature
announcement tweets... or if you could please tag @docugami
https://twitter.com/docugami we would really appreciate that :-)

Co-authored-by: Taqi Jaffri <tjaffri@docugami.com>
This commit is contained in:
Taqi Jaffri 2023-07-24 17:34:13 -07:00 committed by GitHub
parent f7ad14acfa
commit 8f158b72fc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 152 additions and 61 deletions

File diff suppressed because one or more lines are too long

View File

@ -38,6 +38,9 @@ class Replicate(LLM):
streaming: bool = Field(default=False) streaming: bool = Field(default=False)
"""Whether to stream the results.""" """Whether to stream the results."""
stop: Optional[List[str]] = Field(default=[])
"""Stop sequences to early-terminate generation."""
class Config: class Config:
"""Configuration for this pydantic config.""" """Configuration for this pydantic config."""
@ -114,12 +117,27 @@ class Replicate(LLM):
first_input_name = input_properties[0][0] first_input_name = input_properties[0][0]
inputs = {first_input_name: prompt, **self.input} inputs = {first_input_name: prompt, **self.input}
iterator = replicate_python.run(self.model, input={**inputs, **kwargs}) prediction = replicate_python.predictions.create(
full_completion = "" version=version, input={**inputs, **kwargs}
for output in iterator:
full_completion += output
if self.streaming and run_manager:
run_manager.on_llm_new_token(
output,
) )
return full_completion current_completion: str = ""
stop_condition_reached = False
for output in prediction.output_iterator():
current_completion += output
# test for stop conditions, if specified
if stop:
for s in stop:
if s in current_completion:
prediction.cancel()
stop_index = current_completion.find(s)
current_completion = current_completion[:stop_index]
stop_condition_reached = True
break
if stop_condition_reached:
break
if self.streaming and run_manager:
run_manager.on_llm_new_token(output)
return current_completion

View File

@ -25,3 +25,10 @@ def test_replicate_streaming_call() -> None:
output = llm("LangChain") output = llm("LangChain")
assert output == "hello LangChain" assert output == "hello LangChain"
assert callback_handler.llm_streams == 15 assert callback_handler.llm_streams == 15
def test_replicate_stop_sequence() -> None:
"""Test call to Replicate with a stop sequence."""
llm = Replicate(model=TEST_MODEL)
output = llm("one two three", stop=["two"])
assert output == "hello one "