mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
Added streaming support to Replicate (#8045)
Streaming support is useful if you are doing long-running completions or need interactivity e.g. for chat... adding it to replicate, using a similar pattern to other LLMs that support streaming. Housekeeping: I ran `make format` and `make lint`, no issues reported in the files I touched. I did update the replicate integration test but ran into some issues, specifically: 1. The original test was failing for me due to the model argument not being specified... perhaps this test is not regularly run? I fixed it by adding a call to the lightweight hello world model which should not be burdensome for replicate infra. 2. I couldn't get the `make integration_tests` command to pass... a lot of failures in other integration tests due to missing dependencies... however I did make sure the particular test file I updated does pass, by running `poetry run pytest tests/integration_tests/llms/test_replicate.py` Finally, I am @tjaffri https://twitter.com/tjaffri for feature announcement tweets... or if you could please tag @docugami https://twitter.com/docugami we would really appreciate that :-) Tagging model maintainers @hwchase17 @baskaryan Thanks for all the awesome work you folks are doing. --------- Co-authored-by: Taqi Jaffri <tjaffri@docugami.com>
This commit is contained in:
parent
31b7ddc12c
commit
973593c5c7
File diff suppressed because one or more lines are too long
@ -35,6 +35,9 @@ class Replicate(LLM):
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
replicate_api_token: Optional[str] = None
|
||||
|
||||
streaming: bool = Field(default=False)
|
||||
"""Whether to stream the results."""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic config."""
|
||||
|
||||
@ -109,8 +112,14 @@ class Replicate(LLM):
|
||||
key=lambda item: item[1].get("x-order", 0),
|
||||
)
|
||||
first_input_name = input_properties[0][0]
|
||||
|
||||
inputs = {first_input_name: prompt, **self.input}
|
||||
iterator = replicate_python.run(self.model, input={**inputs, **kwargs})
|
||||
|
||||
return "".join([output for output in iterator])
|
||||
iterator = replicate_python.run(self.model, input={**inputs, **kwargs})
|
||||
full_completion = ""
|
||||
for output in iterator:
|
||||
full_completion += output
|
||||
if self.streaming and run_manager:
|
||||
run_manager.on_llm_new_token(
|
||||
output,
|
||||
)
|
||||
return full_completion
|
||||
|
@ -1,10 +1,27 @@
|
||||
"""Test Replicate API wrapper."""
|
||||
|
||||
from langchain.callbacks.manager import CallbackManager
|
||||
from langchain.llms.replicate import Replicate
|
||||
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
|
||||
|
||||
# Lightweight public "hello world" model, cheap enough to call from CI.
TEST_MODEL_NAME = "replicate/hello-world"
# Pinned model version so the test output stays reproducible over time.
TEST_MODEL_VER = "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa"
# Full identifier in the "<owner>/<name>:<version>" form the Replicate API expects.
TEST_MODEL = TEST_MODEL_NAME + ":" + TEST_MODEL_VER
|
||||
|
||||
|
||||
def test_replicate_call() -> None:
    """Test simple non-streaming call to Replicate.

    Integration test: performs a live call against the Replicate API using
    the lightweight hello-world model (see TEST_MODEL), so it stays cheap.
    """
    llm = Replicate(model=TEST_MODEL)
    output = llm("LangChain")
    # The hello-world model echoes "hello <input>", giving a deterministic result.
    assert output == "hello LangChain"
|
||||
|
||||
|
||||
def test_replicate_streaming_call() -> None:
    """Test streaming call to Replicate.

    Integration test: performs a live call with streaming=True and checks
    that the callback handler saw the streamed tokens.
    """
    # FakeCallbackHandler counts on_llm_new_token calls in llm_streams.
    callback_handler = FakeCallbackHandler()
    callback_manager = CallbackManager([callback_handler])

    llm = Replicate(streaming=True, callback_manager=callback_manager, model=TEST_MODEL)
    output = llm("LangChain")
    # Final joined completion must match the non-streaming result.
    assert output == "hello LangChain"
    # Token count the pinned hello-world model version emits for this prompt
    # — NOTE(review): tied to the model version; re-check if TEST_MODEL_VER changes.
    assert callback_handler.llm_streams == 15
|
||||
|
Loading…
Reference in New Issue
Block a user