From 246710def93c891e5653876ceb3265bdc23ab7e9 Mon Sep 17 00:00:00 2001 From: Charlie Holtz Date: Wed, 26 Apr 2023 17:26:33 -0400 Subject: [PATCH] Fix Replicate llm response to handle iterator / multiple outputs (#3614) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit One of our users noticed a bug when calling streaming models. This is because those models return an iterator. So, I've updated the Replicate `_call` code to join together the output. The other advantage of this fix is that if you requested multiple outputs you would get them all – previously I was just returning output[0]. I also adjusted the demo docs to use dolly, because we're featuring that model right now and it's always hot, so people won't have to wait for the model to boot up. The error that this fixes: ``` > llm = Replicate(model=“replicate/flan-t5-xl:eec2f71c986dfa3b7a5d842d22e1130550f015720966bec48beaae059b19ef4c”) > llm(“hello”) > Traceback (most recent call last): File "/Users/charlieholtz/workspace/dev/python/main.py", line 15, in print(llm(prompt)) File "/opt/homebrew/lib/python3.10/site-packages/langchain/llms/base.py", line 246, in __call__ return self.generate([prompt], stop=stop).generations[0][0].text File "/opt/homebrew/lib/python3.10/site-packages/langchain/llms/base.py", line 140, in generate raise e File "/opt/homebrew/lib/python3.10/site-packages/langchain/llms/base.py", line 137, in generate output = self._generate(prompts, stop=stop) File "/opt/homebrew/lib/python3.10/site-packages/langchain/llms/base.py", line 324, in _generate text = self._call(prompt, stop=stop) File "/opt/homebrew/lib/python3.10/site-packages/langchain/llms/replicate.py", line 108, in _call return outputs[0] TypeError: 'generator' object is not subscriptable ``` --- docs/ecosystem/replicate.md | 7 +++---- .../models/llms/integrations/replicate.ipynb | 13 +++++++------ langchain/llms/replicate.py | 4 ++-- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/docs/ecosystem/replicate.md b/docs/ecosystem/replicate.md index e9b604ba..21bd1925 100644 --- a/docs/ecosystem/replicate.md +++ b/docs/ecosystem/replicate.md @@ -9,7 +9,7 @@ This page covers how to run models on Replicate within LangChain. Find a model on the [Replicate explore page](https://replicate.com/explore), and then paste in the model name and version in this format: `owner-name/model-name:version` -For example, for this [flan-t5 model](https://replicate.com/daanelson/flan-t5), click on the API tab. The model name/version would be: `daanelson/flan-t5:04e422a9b85baed86a4f24981d7f9953e20c5fd82f6103b74ebc431588e1cec8` +For example, for this [dolly model](https://replicate.com/replicate/dolly-v2-12b), click on the API tab. The model name/version would be: `"replicate/dolly-v2-12b:ef0e1aefc61f8e096ebe4db6b2bacc297daf2ef6899f0f7e001ec445893500e5"` Only the `model` param is required, but any other model parameters can also be passed in with the format `input={model_param: value, ...}` @@ -24,7 +24,7 @@ Replicate(model="stability-ai/stable-diffusion:db21e45d3f7023abc2a46ee38a23973f6 From here, we can initialize our model: ```python -llm = Replicate(model="daanelson/flan-t5:04e422a9b85baed86a4f24981d7f9953e20c5fd82f6103b74ebc431588e1cec8") +llm = Replicate(model="replicate/dolly-v2-12b:ef0e1aefc61f8e096ebe4db6b2bacc297daf2ef6899f0f7e001ec445893500e5") ``` And run it: @@ -40,8 +40,7 @@ llm(prompt) We can call any Replicate model (not just LLMs) using this syntax. For example, we can call [Stable Diffusion](https://replicate.com/stability-ai/stable-diffusion): ```python -text2image = Replicate(model="stability-ai/stable-diffusion:db21e45d3f7023abc2a46ee38a23973f6dce16bb082a930b0c49861f96d1e5bf", - input={'image_dimensions'='512x512'} +text2image = Replicate(model="stability-ai/stable-diffusion:db21e45d3f7023abc2a46ee38a23973f6dce16bb082a930b0c49861f96d1e5bf", input={'image_dimensions':'512x512'}) image_output = text2image("A cat riding a motorcycle by Picasso") ``` diff --git a/docs/modules/models/llms/integrations/replicate.ipynb b/docs/modules/models/llms/integrations/replicate.ipynb index 0607f77a..5ef5af40 100644 --- a/docs/modules/models/llms/integrations/replicate.ipynb +++ b/docs/modules/models/llms/integrations/replicate.ipynb @@ -44,7 +44,7 @@ }, "outputs": [ { - "name": "stdin", + "name": "stdout", "output_type": "stream", "text": [ " ········\n" @@ -85,6 +85,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -92,7 +93,7 @@ "\n", "Find a model on the [replicate explore page](https://replicate.com/explore), and then paste in the model name and version in this format: model_name/version\n", "\n", - "For example, for this [flan-t5 model]( https://replicate.com/daanelson/flan-t5), click on the API tab. The model name/version would be: `daanelson/flan-t5:04e422a9b85baed86a4f24981d7f9953e20c5fd82f6103b74ebc431588e1cec8`\n", + "For example, for this [dolly model](https://replicate.com/replicate/dolly-v2-12b), click on the API tab. The model name/version would be: `replicate/dolly-v2-12b:ef0e1aefc61f8e096ebe4db6b2bacc297daf2ef6899f0f7e001ec445893500e5`\n", "\n", "Only the `model` param is required, but we can add other model params when initializing.\n", "\n", @@ -113,7 +114,7 @@ }, "outputs": [], "source": [ - "llm = Replicate(model=\"daanelson/flan-t5:04e422a9b85baed86a4f24981d7f9953e20c5fd82f6103b74ebc431588e1cec8\")" + "llm = Replicate(model=\"replicate/dolly-v2-12b:ef0e1aefc61f8e096ebe4db6b2bacc297daf2ef6899f0f7e001ec445893500e5\")" ] }, { @@ -243,7 +244,7 @@ "metadata": {}, "outputs": [], "source": [ - "llm = Replicate(model=\"daanelson/flan-t5:04e422a9b85baed86a4f24981d7f9953e20c5fd82f6103b74ebc431588e1cec8\")\n", + "dolly_llm = Replicate(model=\"replicate/dolly-v2-12b:ef0e1aefc61f8e096ebe4db6b2bacc297daf2ef6899f0f7e001ec445893500e5\")\n", "text2image = Replicate(model=\"stability-ai/stable-diffusion:db21e45d3f7023abc2a46ee38a23973f6dce16bb082a930b0c49861f96d1e5bf\")" ] }, @@ -265,7 +266,7 @@ " template=\"What is a good name for a company that makes {product}?\",\n", ")\n", "\n", - "chain = LLMChain(llm=llm, prompt=prompt)" + "chain = LLMChain(llm=dolly_llm, prompt=prompt)" ] }, { @@ -285,7 +286,7 @@ " input_variables=[\"company_name\"],\n", " template=\"Write a description of a logo for this company: {company_name}\",\n", ")\n", - "chain_two = LLMChain(llm=llm, prompt=second_prompt)" + "chain_two = LLMChain(llm=dolly_llm, prompt=second_prompt)" ] }, { diff --git a/langchain/llms/replicate.py b/langchain/llms/replicate.py index 42213a49..6b487230 100644 --- a/langchain/llms/replicate.py +++ b/langchain/llms/replicate.py @@ -103,6 +103,6 @@ class Replicate(LLM): first_input_name = input_properties[0][0] inputs = {first_input_name: prompt, **self.input} + iterator = replicate_python.run(self.model, input={**inputs}) - outputs = replicate_python.run(self.model, input={**inputs}) - return outputs[0] + return "".join([output for output in iterator])