Mirror of https://github.com/hwchase17/langchain, synced 2024-11-06 03:20:49 +00:00
Added stop sequence support to replicate (#8107)
Stop sequences are useful if you are doing long-running completions and need to early-out rather than running for the full max_length: not only does this save inference cost on Replicate, it is also much faster if you are going to truncate the output later anyway. Other LLMs support stop sequences natively (e.g. OpenAI), but I didn't see this for Replicate, so this adds it via their prediction cancel method.

Housekeeping: I ran `make format` and `make lint`; no issues were reported in the files I touched. I also updated the Replicate integration test and ran `poetry run pytest tests/integration_tests/llms/test_replicate.py` successfully.

Finally, I am @tjaffri (https://twitter.com/tjaffri) for feature announcement tweets... or if you could please tag @docugami (https://twitter.com/docugami) we would really appreciate that :-)

Co-authored-by: Taqi Jaffri <tjaffri@docugami.com>
This commit is contained in: parent f7ad14acfa, commit 8f158b72fc
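For illustration, a minimal usage sketch of the new parameter, assuming a LangChain build that includes this change; the model identifier is a placeholder, and REPLICATE_API_TOKEN must be set in the environment:

```python
from langchain.llms import Replicate

# Placeholder model id -- substitute a real "owner/name:version" string.
llm = Replicate(model="replicate/some-model:0123abcd")

# The prediction is cancelled as soon as "\n" appears in the accumulated
# output, saving the cost of generating text that would be discarded anyway.
output = llm("Q: What is the capital of France?\nA:", stop=["\n"])
```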
File diff suppressed because one or more lines are too long
@@ -38,6 +38,9 @@ class Replicate(LLM):
     streaming: bool = Field(default=False)
     """Whether to stream the results."""

+    stop: Optional[List[str]] = Field(default=[])
+    """Stop sequences to early-terminate generation."""
+
     class Config:
         """Configuration for this pydantic config."""

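For context, a minimal sketch of how a pydantic field declared this way behaves; the class name here is a stand-in, not the real Replicate class:

```python
from typing import List, Optional
from pydantic import BaseModel, Field

class StubLLM(BaseModel):
    """Stand-in for the Replicate LLM class in the diff above."""
    stop: Optional[List[str]] = Field(default=[])

print(StubLLM().stop)             # [] -- no stop sequences by default
print(StubLLM(stop=["\n"]).stop)  # ['\n'] -- configured per instance
```

Pydantic copies the mutable `[]` default for each instance, so the empty-list default is safe here, unlike a plain dataclass field.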
@@ -114,12 +117,27 @@ class Replicate(LLM):
         first_input_name = input_properties[0][0]
         inputs = {first_input_name: prompt, **self.input}

-        iterator = replicate_python.run(self.model, input={**inputs, **kwargs})
-        full_completion = ""
-        for output in iterator:
-            full_completion += output
-            if self.streaming and run_manager:
-                run_manager.on_llm_new_token(
-                    output,
-                )
-        return full_completion
+        prediction = replicate_python.predictions.create(
+            version=version, input={**inputs, **kwargs}
+        )
+        current_completion: str = ""
+        stop_condition_reached = False
+        for output in prediction.output_iterator():
+            current_completion += output
+
+            # test for stop conditions, if specified
+            if stop:
+                for s in stop:
+                    if s in current_completion:
+                        prediction.cancel()
+                        stop_index = current_completion.find(s)
+                        current_completion = current_completion[:stop_index]
+                        stop_condition_reached = True
+                        break
+
+            if stop_condition_reached:
+                break
+
+            if self.streaming and run_manager:
+                run_manager.on_llm_new_token(output)
+        return current_completion
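The heart of the change is the scan-and-cancel loop above: accumulate streamed chunks, and as soon as any stop sequence appears in the accumulated text, cancel the prediction and slice the output at the match. A standalone, hedged sketch of just the truncation part (the function name and the fake token stream are illustrative; the real code also calls prediction.cancel() to stop the remote job):

```python
from typing import Iterable, List

def truncate_at_stop(tokens: Iterable[str], stop: List[str]) -> str:
    """Accumulate streamed tokens, cutting at the first stop sequence."""
    completion = ""
    for token in tokens:
        completion += token
        for s in stop:
            idx = completion.find(s)
            if idx != -1:
                # The real integration cancels the Replicate prediction here,
                # so no further inference cost is incurred.
                return completion[:idx]
    return completion

# A stop sequence is detected even when it spans a chunk boundary:
print(truncate_at_stop(["hello ", "one t", "wo three"], ["two"]))  # hello one
```

Because the check runs against the accumulated completion rather than each individual chunk, a stop sequence split across two streamed chunks is still caught.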
@@ -25,3 +25,10 @@ def test_replicate_streaming_call() -> None:
     output = llm("LangChain")
     assert output == "hello LangChain"
     assert callback_handler.llm_streams == 15
+
+
+def test_replicate_stop_sequence() -> None:
+    """Test call to Replicate with a stop sequence."""
+    llm = Replicate(model=TEST_MODEL)
+    output = llm("one two three", stop=["two"])
+    assert output == "hello one "
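Why the test expects "hello one ": the streaming test above shows the integration test model echoes "hello " plus the prompt, so the full completion would be "hello one two three", which the stop logic slices at the first occurrence of "two". A quick check of that slicing:

```python
completion = "hello one two three"  # assumed full echo from the test model
print(completion[: completion.find("two")])  # -> "hello one "
```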