Added stop sequence support to replicate (#8107)

Stop sequences are useful if you are doing long-running completions and
need to early-out rather than running for the full max_length... not
only does this save inference cost on Replicate, it is also much faster
if you are going to truncate the output later anyway.

Other LLMs support stop sequences natively (e.g. OpenAI), but I didn't
see support for this in Replicate, so I am adding it via their
prediction-cancel method.

Housekeeping: I ran `make format` and `make lint`, no issues reported in
the files I touched.

I did update the replicate integration test and ran `poetry run pytest
tests/integration_tests/llms/test_replicate.py` successfully.

Finally, I am @tjaffri https://twitter.com/tjaffri for feature
announcement tweets... or if you could please tag @docugami
https://twitter.com/docugami we would really appreciate that :-)

Co-authored-by: Taqi Jaffri <tjaffri@docugami.com>
pull/8206/head^2
Taqi Jaffri 1 year ago committed by GitHub
parent f7ad14acfa
commit 8f158b72fc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

File diff suppressed because one or more lines are too long

@ -38,6 +38,9 @@ class Replicate(LLM):
streaming: bool = Field(default=False)
"""Whether to stream the results."""
stop: Optional[List[str]] = Field(default=[])
"""Stop sequences to early-terminate generation."""
class Config:
"""Configuration for this pydantic config."""
@ -114,12 +117,27 @@ class Replicate(LLM):
first_input_name = input_properties[0][0]
inputs = {first_input_name: prompt, **self.input}
iterator = replicate_python.run(self.model, input={**inputs, **kwargs})
full_completion = ""
for output in iterator:
full_completion += output
prediction = replicate_python.predictions.create(
version=version, input={**inputs, **kwargs}
)
current_completion: str = ""
stop_condition_reached = False
for output in prediction.output_iterator():
current_completion += output
# test for stop conditions, if specified
if stop:
for s in stop:
if s in current_completion:
prediction.cancel()
stop_index = current_completion.find(s)
current_completion = current_completion[:stop_index]
stop_condition_reached = True
break
if stop_condition_reached:
break
if self.streaming and run_manager:
run_manager.on_llm_new_token(
output,
)
return full_completion
run_manager.on_llm_new_token(output)
return current_completion

@ -25,3 +25,10 @@ def test_replicate_streaming_call() -> None:
output = llm("LangChain")
assert output == "hello LangChain"
assert callback_handler.llm_streams == 15
def test_replicate_stop_sequence() -> None:
    """Verify that Replicate truncates output at a requested stop sequence."""
    llm = Replicate(model=TEST_MODEL)
    # Generation should be cut off just before the stop token "two",
    # leaving only the text that preceded it.
    result = llm("one two three", stop=["two"])
    assert result == "hello one "

Loading…
Cancel
Save