Added streaming support to Replicate (#8045)

Streaming support is useful for long-running completions or when you need
interactivity, e.g. for chat. This PR adds it to Replicate, using a similar
pattern to other LLMs that support streaming.
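For reviewers, here is a minimal usage sketch of the new flag. The model string is the hello-world model from the integration test; the stdout handler and the `REPLICATE_API_TOKEN` note are illustrative, not part of this diff:

```python
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.llms import Replicate

# Assumes REPLICATE_API_TOKEN is set in the environment.
llm = Replicate(
    model="replicate/hello-world:"
    "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
    streaming=True,
    callbacks=[StreamingStdOutCallbackHandler()],
)

# Tokens are printed to stdout as they arrive; the full completion is still returned.
output = llm("LangChain")
```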

Housekeeping: I ran `make format` and `make lint`; no issues were reported in
the files I touched.

I did update the Replicate integration test, but ran into some issues,
specifically:

1. The original test was failing for me because the model argument was not
specified... perhaps this test is not run regularly? I fixed it by adding a
call to the lightweight `replicate/hello-world` model, which should not be
burdensome for Replicate's infrastructure.
2. I couldn't get the `make integration_tests` command to pass, due to many
failures in other integration tests caused by missing dependencies. However,
I did make sure the particular test file I updated passes, by running
`poetry run pytest tests/integration_tests/llms/test_replicate.py`.

Finally, I am @tjaffri (https://twitter.com/tjaffri) for feature announcement
tweets... or, if you could please tag @docugami
(https://twitter.com/docugami), we would really appreciate that :-)

Tagging model maintainers @hwchase17 @baskaryan

Thanks for all the awesome work you folks are doing.

---------

Co-authored-by: Taqi Jaffri <tjaffri@docugami.com>
Commit 973593c5c7 (parent 31b7ddc12c), authored by Taqi Jaffri on 2023-07-20 18:59:54 -07:00 and committed by GitHub.
3 changed files with 163 additions and 56 deletions.

(The diff for one of the three changed files is suppressed because one or more lines are too long.)

langchain/llms/replicate.py

```diff
@@ -35,6 +35,9 @@ class Replicate(LLM):
     model_kwargs: Dict[str, Any] = Field(default_factory=dict)
     replicate_api_token: Optional[str] = None
 
+    streaming: bool = Field(default=False)
+    """Whether to stream the results."""
+
     class Config:
         """Configuration for this pydantic config."""
@@ -109,8 +112,14 @@ class Replicate(LLM):
             key=lambda item: item[1].get("x-order", 0),
         )
         first_input_name = input_properties[0][0]
         inputs = {first_input_name: prompt, **self.input}
 
-        iterator = replicate_python.run(self.model, input={**inputs, **kwargs})
-        return "".join([output for output in iterator])
+        iterator = replicate_python.run(self.model, input={**inputs, **kwargs})
+        full_completion = ""
+        for output in iterator:
+            full_completion += output
+            if self.streaming and run_manager:
+                run_manager.on_llm_new_token(
+                    output,
+                )
+        return full_completion
```
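The loop above forwards each chunk to `run_manager.on_llm_new_token`, which is how streamed tokens reach user-supplied callbacks. As a rough sketch (the handler class below is hypothetical, not part of this change), a custom handler could consume the stream like this:

```python
from typing import Any, List

from langchain.callbacks.base import BaseCallbackHandler


class TokenCollector(BaseCallbackHandler):
    """Hypothetical handler that records streamed tokens as they arrive."""

    def __init__(self) -> None:
        self.tokens: List[str] = []

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        # Invoked once per chunk yielded by replicate_python.run when streaming=True.
        self.tokens.append(token)
```

Passing an instance via `callbacks=[TokenCollector()]` (or via a `CallbackManager`, as the updated integration test does) would then receive every streamed chunk.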

tests/integration_tests/llms/test_replicate.py

```diff
@@ -1,10 +1,27 @@
 """Test Replicate API wrapper."""
 
+from langchain.callbacks.manager import CallbackManager
 from langchain.llms.replicate import Replicate
+from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
+
+TEST_MODEL_NAME = "replicate/hello-world"
+TEST_MODEL_VER = "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa"
+TEST_MODEL = TEST_MODEL_NAME + ":" + TEST_MODEL_VER
 
 
 def test_replicate_call() -> None:
-    """Test valid call to Replicate."""
-    llm = Replicate()
-    output = llm("Say foo:")
-    assert isinstance(output, str)
+    """Test simple non-streaming call to Replicate."""
+    llm = Replicate(model=TEST_MODEL)
+    output = llm("LangChain")
+    assert output == "hello LangChain"
+
+
+def test_replicate_streaming_call() -> None:
+    """Test streaming call to Replicate."""
+    callback_handler = FakeCallbackHandler()
+    callback_manager = CallbackManager([callback_handler])
+    llm = Replicate(
+        streaming=True, callback_manager=callback_manager, model=TEST_MODEL
+    )
+    output = llm("LangChain")
+    assert output == "hello LangChain"
+    assert callback_handler.llm_streams == 15
```