community: update Replicate to work with official models (#20633)

Description: you don't need to pass a version for Replicate official
models. That was broken on LangChain until now!

You can now run: 

```
llm = Replicate(
    model="meta/meta-llama-3-8b-instruct",
    model_kwargs={"temperature": 0.75, "max_length": 500, "top_p": 1},
)
prompt = """
User: Answer the following yes/no question by reasoning step by step. Can a dog drive a car?
Assistant:
"""
llm(prompt)
```

I've updated the replicate.ipynb to reflect that.

twitter: @charliebholtz

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
Charlie Holtz 2024-04-18 18:43:40 -07:00 committed by GitHub
parent dd5139e304
commit 1cbab0ebda
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 41 additions and 28 deletions

View File

@ -38,7 +38,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 7, "execution_count": 4,
"metadata": { "metadata": {
"scrolled": true, "scrolled": true,
"tags": [] "tags": []
@ -49,17 +49,21 @@
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Collecting replicate\n", "Collecting replicate\n",
" Using cached replicate-0.9.0-py3-none-any.whl (21 kB)\n", " Using cached replicate-0.25.1-py3-none-any.whl.metadata (24 kB)\n",
"Requirement already satisfied: packaging in /root/Source/github/docugami.langchain/libs/langchain/.venv/lib/python3.9/site-packages (from replicate) (23.1)\n", "Requirement already satisfied: httpx<1,>=0.21.0 in /Users/charlieholtz/miniconda3/envs/langchain/lib/python3.9/site-packages (from replicate) (0.24.1)\n",
"Requirement already satisfied: pydantic>1 in /root/Source/github/docugami.langchain/libs/langchain/.venv/lib/python3.9/site-packages (from replicate) (1.10.9)\n", "Requirement already satisfied: packaging in /Users/charlieholtz/miniconda3/envs/langchain/lib/python3.9/site-packages (from replicate) (23.2)\n",
"Requirement already satisfied: requests>2 in /root/Source/github/docugami.langchain/libs/langchain/.venv/lib/python3.9/site-packages (from replicate) (2.28.2)\n", "Requirement already satisfied: pydantic>1.10.7 in /Users/charlieholtz/miniconda3/envs/langchain/lib/python3.9/site-packages (from replicate) (1.10.14)\n",
"Requirement already satisfied: typing-extensions>=4.2.0 in /root/Source/github/docugami.langchain/libs/langchain/.venv/lib/python3.9/site-packages (from pydantic>1->replicate) (4.5.0)\n", "Requirement already satisfied: typing-extensions>=4.5.0 in /Users/charlieholtz/miniconda3/envs/langchain/lib/python3.9/site-packages (from replicate) (4.10.0)\n",
"Requirement already satisfied: charset-normalizer<4,>=2 in /root/Source/github/docugami.langchain/libs/langchain/.venv/lib/python3.9/site-packages (from requests>2->replicate) (3.1.0)\n", "Requirement already satisfied: certifi in /Users/charlieholtz/miniconda3/envs/langchain/lib/python3.9/site-packages (from httpx<1,>=0.21.0->replicate) (2024.2.2)\n",
"Requirement already satisfied: idna<4,>=2.5 in /root/Source/github/docugami.langchain/libs/langchain/.venv/lib/python3.9/site-packages (from requests>2->replicate) (3.4)\n", "Requirement already satisfied: httpcore<0.18.0,>=0.15.0 in /Users/charlieholtz/miniconda3/envs/langchain/lib/python3.9/site-packages (from httpx<1,>=0.21.0->replicate) (0.17.3)\n",
"Requirement already satisfied: urllib3<1.27,>=1.21.1 in /root/Source/github/docugami.langchain/libs/langchain/.venv/lib/python3.9/site-packages (from requests>2->replicate) (1.26.16)\n", "Requirement already satisfied: idna in /Users/charlieholtz/miniconda3/envs/langchain/lib/python3.9/site-packages (from httpx<1,>=0.21.0->replicate) (3.6)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /root/Source/github/docugami.langchain/libs/langchain/.venv/lib/python3.9/site-packages (from requests>2->replicate) (2023.5.7)\n", "Requirement already satisfied: sniffio in /Users/charlieholtz/miniconda3/envs/langchain/lib/python3.9/site-packages (from httpx<1,>=0.21.0->replicate) (1.3.1)\n",
"Requirement already satisfied: h11<0.15,>=0.13 in /Users/charlieholtz/miniconda3/envs/langchain/lib/python3.9/site-packages (from httpcore<0.18.0,>=0.15.0->httpx<1,>=0.21.0->replicate) (0.14.0)\n",
"Requirement already satisfied: anyio<5.0,>=3.0 in /Users/charlieholtz/miniconda3/envs/langchain/lib/python3.9/site-packages (from httpcore<0.18.0,>=0.15.0->httpx<1,>=0.21.0->replicate) (3.7.1)\n",
"Requirement already satisfied: exceptiongroup in /Users/charlieholtz/miniconda3/envs/langchain/lib/python3.9/site-packages (from anyio<5.0,>=3.0->httpcore<0.18.0,>=0.15.0->httpx<1,>=0.21.0->replicate) (1.2.0)\n",
"Using cached replicate-0.25.1-py3-none-any.whl (39 kB)\n",
"Installing collected packages: replicate\n", "Installing collected packages: replicate\n",
"Successfully installed replicate-0.9.0\n" "Successfully installed replicate-0.25.1\n"
] ]
} }
], ],
@ -69,7 +73,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 3, "execution_count": 5,
"metadata": { "metadata": {
"tags": [] "tags": []
}, },
@ -84,7 +88,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 4, "execution_count": 6,
"metadata": { "metadata": {
"tags": [] "tags": []
}, },
@ -97,7 +101,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 18, "execution_count": 56,
"metadata": { "metadata": {
"tags": [] "tags": []
}, },
@ -116,28 +120,28 @@
"\n", "\n",
"Find a model on the [replicate explore page](https://replicate.com/explore), and then paste in the model name and version in this format: model_name/version.\n", "Find a model on the [replicate explore page](https://replicate.com/explore), and then paste in the model name and version in this format: model_name/version.\n",
"\n", "\n",
"For example, here is [`LLama-V2`](https://replicate.com/a16z-infra/llama13b-v2-chat)." "For example, here is [`Meta Llama 3`](https://replicate.com/meta/meta-llama-3-8b-instruct)."
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 19, "execution_count": 58,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"'1. Dogs do not have the ability to operate complex machinery like cars.\\n2. Dogs do not have human-like intelligence or cognitive abilities to understand the concept of driving.\\n3. Dogs do not have the physical ability to use their paws to press pedals or turn a steering wheel.\\n4. Therefore, a dog cannot drive a car.'" "\"Let's break this down step by step:\\n\\n1. A dog is a living being, specifically a mammal.\\n2. Dogs do not possess the cognitive abilities or physical characteristics necessary to operate a vehicle, such as a car.\\n3. Operating a car requires complex mental and physical abilities, including:\\n\\t* Understanding of traffic laws and rules\\n\\t* Ability to read and comprehend road signs\\n\\t* Ability to make decisions quickly and accurately\\n\\t* Ability to physically manipulate the vehicle's controls (e.g., steering wheel, pedals)\\n4. Dogs do not possess any of these abilities. They are unable to read or comprehend written language, let alone complex traffic laws.\\n5. Dogs also lack the physical dexterity and coordination to operate a vehicle's controls. Their paws and claws are not adapted for grasping or manipulating small, precise objects like a steering wheel or pedals.\\n6. Therefore, it is not possible for a dog to drive a car.\\n\\nAnswer: No.\""
] ]
}, },
"execution_count": 19, "execution_count": 58,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
], ],
"source": [ "source": [
"llm = Replicate(\n", "llm = Replicate(\n",
" model=\"a16z-infra/llama13b-v2-chat:df7690f1994d94e96ad9d568eac121aecf50684a0b0963b25a41cc40061269e5\",\n", " model=\"meta/meta-llama-3-8b-instruct\",\n",
" model_kwargs={\"temperature\": 0.75, \"max_length\": 500, \"top_p\": 1},\n", " model_kwargs={\"temperature\": 0.75, \"max_length\": 500, \"top_p\": 1},\n",
")\n", ")\n",
"prompt = \"\"\"\n", "prompt = \"\"\"\n",
@ -195,7 +199,7 @@
], ],
"source": [ "source": [
"prompt = \"\"\"\n", "prompt = \"\"\"\n",
"Answer the following yes/no question by reasoning step by step. \n", "Answer the following yes/no question by reasoning step by step.\n",
"Can a dog drive a car?\n", "Can a dog drive a car?\n",
"\"\"\"\n", "\"\"\"\n",
"llm(prompt)" "llm(prompt)"
@ -554,7 +558,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.9.1" "version": "3.9.19"
}, },
"vscode": { "vscode": {
"interpreter": { "interpreter": {

View File

@ -44,7 +44,7 @@ class Replicate(LLM):
replicate_api_token: Optional[str] = None replicate_api_token: Optional[str] = None
prompt_key: Optional[str] = None prompt_key: Optional[str] = None
version_obj: Any = Field(default=None, exclude=True) version_obj: Any = Field(default=None, exclude=True)
"""Optionally pass in the model version object during initialization to avoid """Optionally pass in the model version object during initialization to avoid
having to make an extra API call to retrieve it during streaming. NOTE: not having to make an extra API call to retrieve it during streaming. NOTE: not
serializable, is excluded from serialization. serializable, is excluded from serialization.
""" """
@ -197,9 +197,13 @@ class Replicate(LLM):
# get the model and version # get the model and version
if self.version_obj is None: if self.version_obj is None:
model_str, version_str = self.model.split(":") if ":" in self.model:
model = replicate_python.models.get(model_str) model_str, version_str = self.model.split(":")
self.version_obj = model.versions.get(version_str) model = replicate_python.models.get(model_str)
self.version_obj = model.versions.get(version_str)
else:
model = replicate_python.models.get(self.model)
self.version_obj = model.latest_version
if self.prompt_key is None: if self.prompt_key is None:
# sort through the openapi schema to get the name of the first input # sort through the openapi schema to get the name of the first input
@ -217,6 +221,11 @@ class Replicate(LLM):
**self.model_kwargs, **self.model_kwargs,
**kwargs, **kwargs,
} }
return replicate_python.predictions.create(
version=self.version_obj, input=input_ # if it's an official model
) if ":" not in self.model:
return replicate_python.models.predictions.create(self.model, input=input_)
else:
return replicate_python.predictions.create(
version=self.version_obj, input=input_
)