Fix Replicate llm response to handle iterator / multiple outputs (#3614)

One of our users noticed a bug when calling streaming models: those models return an iterator, which the Replicate `_call` code then tried to index into. I've updated `_call` to join the iterator's output into a single string. A side benefit of this fix is that if you request multiple outputs you now get all of them; previously only `output[0]` was returned.
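
Concretely, the change boils down to the sketch below, a simplified standalone version of the edit to `langchain/llms/replicate.py` (the `run_model` function name and the import alias are illustrative, not the library's actual structure):

```python
# Simplified sketch of the fix, not the exact library code.
import replicate as replicate_python  # illustrative alias; the wrapper calls replicate_python.run()


def run_model(model: str, inputs: dict) -> str:
    # For streaming models, replicate.run() returns a generator of text chunks,
    # so indexing into the result raised "TypeError: 'generator' object is not subscriptable".
    # Joining the iterator fixes that and also concatenates multiple outputs
    # instead of returning only the first one.
    iterator = replicate_python.run(model, input={**inputs})
    return "".join([output for output in iterator])
```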

I also adjusted the demo docs to use Dolly, because we're featuring that model right now and it's kept hot (always loaded), so people won't have to wait for the model to boot up.

The error that this fixes:
```
> llm = Replicate(model="replicate/flan-t5-xl:eec2f71c986dfa3b7a5d842d22e1130550f015720966bec48beaae059b19ef4c")
> llm("hello")
Traceback (most recent call last):
  File "/Users/charlieholtz/workspace/dev/python/main.py", line 15, in <module>
    print(llm(prompt))
  File "/opt/homebrew/lib/python3.10/site-packages/langchain/llms/base.py", line 246, in __call__
    return self.generate([prompt], stop=stop).generations[0][0].text
  File "/opt/homebrew/lib/python3.10/site-packages/langchain/llms/base.py", line 140, in generate
    raise e
  File "/opt/homebrew/lib/python3.10/site-packages/langchain/llms/base.py", line 137, in generate
    output = self._generate(prompts, stop=stop)
  File "/opt/homebrew/lib/python3.10/site-packages/langchain/llms/base.py", line 324, in _generate
    text = self._call(prompt, stop=stop)
  File "/opt/homebrew/lib/python3.10/site-packages/langchain/llms/replicate.py", line 108, in _call
    return outputs[0]
TypeError: 'generator' object is not subscriptable
```
Charlie Holtz 2023-04-26 17:26:33 -04:00 committed by GitHub
parent 7536912125
commit 246710def9
3 changed files with 12 additions and 12 deletions


@@ -9,7 +9,7 @@ This page covers how to run models on Replicate within LangChain.
 Find a model on the [Replicate explore page](https://replicate.com/explore), and then paste in the model name and version in this format: `owner-name/model-name:version`
-For example, for this [flan-t5 model](https://replicate.com/daanelson/flan-t5), click on the API tab. The model name/version would be: `daanelson/flan-t5:04e422a9b85baed86a4f24981d7f9953e20c5fd82f6103b74ebc431588e1cec8`
+For example, for this [dolly model](https://replicate.com/replicate/dolly-v2-12b), click on the API tab. The model name/version would be: `"replicate/dolly-v2-12b:ef0e1aefc61f8e096ebe4db6b2bacc297daf2ef6899f0f7e001ec445893500e5"`
 Only the `model` param is required, but any other model parameters can also be passed in with the format `input={model_param: value, ...}`

@@ -24,7 +24,7 @@ Replicate(model="stability-ai/stable-diffusion:db21e45d3f7023abc2a46ee38a23973f6
 From here, we can initialize our model:
 ```python
-llm = Replicate(model="daanelson/flan-t5:04e422a9b85baed86a4f24981d7f9953e20c5fd82f6103b74ebc431588e1cec8")
+llm = Replicate(model="replicate/dolly-v2-12b:ef0e1aefc61f8e096ebe4db6b2bacc297daf2ef6899f0f7e001ec445893500e5")
 ```
 And run it:

@@ -40,8 +40,7 @@ llm(prompt)
 We can call any Replicate model (not just LLMs) using this syntax. For example, we can call [Stable Diffusion](https://replicate.com/stability-ai/stable-diffusion):
 ```python
-text2image = Replicate(model="stability-ai/stable-diffusion:db21e45d3f7023abc2a46ee38a23973f6dce16bb082a930b0c49861f96d1e5bf",
-input={'image_dimensions'='512x512'}
+text2image = Replicate(model="stability-ai/stable-diffusion:db21e45d3f7023abc2a46ee38a23973f6dce16bb082a930b0c49861f96d1e5bf", input={'image_dimensions':'512x512'})
 image_output = text2image("A cat riding a motorcycle by Picasso")
 ```
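
Put together, the updated docs example runs roughly like the sketch below (it assumes the `REPLICATE_API_TOKEN` environment variable is set; the LLM prompt is a placeholder):

```python
# Rough end-to-end version of the updated docs snippets.
# Assumes REPLICATE_API_TOKEN is set in the environment.
from langchain.llms import Replicate

llm = Replicate(
    model="replicate/dolly-v2-12b:ef0e1aefc61f8e096ebe4db6b2bacc297daf2ef6899f0f7e001ec445893500e5"
)
print(llm("What is the capital of France?"))  # placeholder prompt

# Non-LLM models use the same syntax; extra model params go in `input`.
text2image = Replicate(
    model="stability-ai/stable-diffusion:db21e45d3f7023abc2a46ee38a23973f6dce16bb082a930b0c49861f96d1e5bf",
    input={"image_dimensions": "512x512"},
)
image_output = text2image("A cat riding a motorcycle by Picasso")
print(image_output)  # URL of the generated image
```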


@@ -44,7 +44,7 @@
 },
 "outputs": [
 {
-"name": "stdin",
+"name": "stdout",
 "output_type": "stream",
 "text": [
 " ········\n"

@@ -85,6 +85,7 @@
 ]
 },
 {
+"attachments": {},
 "cell_type": "markdown",
 "metadata": {},
 "source": [

@@ -92,7 +93,7 @@
 "\n",
 "Find a model on the [replicate explore page](https://replicate.com/explore), and then paste in the model name and version in this format: model_name/version\n",
 "\n",
-"For example, for this [flan-t5 model]( https://replicate.com/daanelson/flan-t5), click on the API tab. The model name/version would be: `daanelson/flan-t5:04e422a9b85baed86a4f24981d7f9953e20c5fd82f6103b74ebc431588e1cec8`\n",
+"For example, for this [dolly model](https://replicate.com/replicate/dolly-v2-12b), click on the API tab. The model name/version would be: `replicate/dolly-v2-12b:ef0e1aefc61f8e096ebe4db6b2bacc297daf2ef6899f0f7e001ec445893500e5`\n",
 "\n",
 "Only the `model` param is required, but we can add other model params when initializing.\n",
 "\n",

@@ -113,7 +114,7 @@
 },
 "outputs": [],
 "source": [
-"llm = Replicate(model=\"daanelson/flan-t5:04e422a9b85baed86a4f24981d7f9953e20c5fd82f6103b74ebc431588e1cec8\")"
+"llm = Replicate(model=\"replicate/dolly-v2-12b:ef0e1aefc61f8e096ebe4db6b2bacc297daf2ef6899f0f7e001ec445893500e5\")"
 ]
 },
 {

@@ -243,7 +244,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"llm = Replicate(model=\"daanelson/flan-t5:04e422a9b85baed86a4f24981d7f9953e20c5fd82f6103b74ebc431588e1cec8\")\n",
+"dolly_llm = Replicate(model=\"replicate/dolly-v2-12b:ef0e1aefc61f8e096ebe4db6b2bacc297daf2ef6899f0f7e001ec445893500e5\")\n",
 "text2image = Replicate(model=\"stability-ai/stable-diffusion:db21e45d3f7023abc2a46ee38a23973f6dce16bb082a930b0c49861f96d1e5bf\")"
 ]
 },
 {

@@ -265,7 +266,7 @@
 " template=\"What is a good name for a company that makes {product}?\",\n",
 ")\n",
 "\n",
-"chain = LLMChain(llm=llm, prompt=prompt)"
+"chain = LLMChain(llm=dolly_llm, prompt=prompt)"
 ]
 },
 {

@@ -285,7 +286,7 @@
 " input_variables=[\"company_name\"],\n",
 " template=\"Write a description of a logo for this company: {company_name}\",\n",
 ")\n",
-"chain_two = LLMChain(llm=llm, prompt=second_prompt)"
+"chain_two = LLMChain(llm=dolly_llm, prompt=second_prompt)"
 ]
 },
 {
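
For reference, the notebook's chaining cells, with the renamed `dolly_llm`, condense to roughly this sketch (the imports and the first prompt's `input_variables` are assumptions not shown in the diff):

```python
# Condensed sketch of the notebook cells touched by this diff.
# Assumes REPLICATE_API_TOKEN is set in the environment.
from langchain.llms import Replicate
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain

dolly_llm = Replicate(
    model="replicate/dolly-v2-12b:ef0e1aefc61f8e096ebe4db6b2bacc297daf2ef6899f0f7e001ec445893500e5"
)

prompt = PromptTemplate(
    input_variables=["product"],  # assumed; only the template text appears in the diff
    template="What is a good name for a company that makes {product}?",
)
chain = LLMChain(llm=dolly_llm, prompt=prompt)

second_prompt = PromptTemplate(
    input_variables=["company_name"],
    template="Write a description of a logo for this company: {company_name}",
)
chain_two = LLMChain(llm=dolly_llm, prompt=second_prompt)
```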


@@ -103,6 +103,6 @@ class Replicate(LLM):
 first_input_name = input_properties[0][0]
 inputs = {first_input_name: prompt, **self.input}
-outputs = replicate_python.run(self.model, input={**inputs})
-return outputs[0]
+iterator = replicate_python.run(self.model, input={**inputs})
+return "".join([output for output in iterator])
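
As a quick sanity check on the approach: indexing a generator is what blew up before, while `"".join(...)` consumes it and concatenates every yielded chunk. A toy illustration (not Replicate-specific):

```python
# Toy illustration of why the join fixes the TypeError above.
def fake_stream():
    yield "Hello, "
    yield "world"

chunks = fake_stream()
# chunks[0]  # TypeError: 'generator' object is not subscriptable
print("".join(chunks))  # -> Hello, world
```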