fix(petals): allow running models other than Bloom (support for Llama and newer models) (#8356)

In this PR:

- Removed the loading logic that restricted Petals to Bloom models
- Removed the Bloom-specific imports (DistributedBloomForCausalLM,
BloomTokenizerFast)
- Imported the more general loaders instead
(AutoDistributedModelForCausalLM, AutoTokenizer)
- Updated the Petals example notebook so that Petals can be installed
successfully on Apple Silicon Macs

- Tag maintainer: @hwchase17, @baskaryan
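
The loader change listed above can be sketched roughly as follows. This is a minimal illustration, not the code from this PR: the helper name `load_petals_client` and the wording of the error message are assumptions made for the example, and only the two imported class names come from the diff. Unlike the diff, the sketch keeps only the imports inside the `try` block so that a missing package is reported separately from a failed model download.

```python
from typing import Any, Dict


def load_petals_client(model_name: str) -> Dict[str, Any]:
    """Load a tokenizer and a distributed model for a Petals-hosted checkpoint.

    Hypothetical helper for illustration; it is not part of LangChain or Petals.
    """
    try:
        # Imported lazily so this module can be imported without petals installed.
        from petals import AutoDistributedModelForCausalLM
        from transformers import AutoTokenizer
    except ImportError as exc:
        # Assumed error text; the install command is the one the notebook gives.
        raise ImportError(
            "Could not import petals or transformers. "
            "Install them with `pip3 install petals`."
        ) from exc

    # The Auto* classes pick the architecture from the checkpoint's config,
    # so the same call works for Bloom, Llama, and other supported models.
    return {
        "tokenizer": AutoTokenizer.from_pretrained(model_name),
        "client": AutoDistributedModelForCausalLM.from_pretrained(model_name),
    }
```

Calling `load_petals_client(model_name)` with a model served by a Petals swarm needs network access and may download tokenizer and model files, so it is best tried outside of automated tests.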

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
pull/8395/head
Karan V 1 year ago committed by GitHub
parent e758e9e7f5
commit a003a0baf6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@@ -16,7 +16,9 @@
    "metadata": {},
    "source": [
     "## Install petals\n",
-    "The `petals` package is required to use the Petals API. Install `petals` using `pip3 install petals`."
+    "The `petals` package is required to use the Petals API. Install `petals` using `pip3 install petals`.\n",
+    "\n",
+    "For Apple Silicon(M1/M2) users please follow this guide [https://github.com/bigscience-workshop/petals/issues/147#issuecomment-1365379642](https://github.com/bigscience-workshop/petals/issues/147#issuecomment-1365379642) to install petals "
    ]
   },
   {
@@ -62,7 +64,7 @@
    },
    "outputs": [
     {
-     "name": "stdin",
+     "name": "stdout",
      "output_type": "stream",
      "text": [
       " ········\n"

@@ -93,12 +93,14 @@ class Petals(LLM):
             values, "huggingface_api_key", "HUGGINGFACE_API_KEY"
         )
         try:
-            from petals import DistributedBloomForCausalLM
-            from transformers import BloomTokenizerFast
+            from petals import AutoDistributedModelForCausalLM
+            from transformers import AutoTokenizer
             model_name = values["model_name"]
-            values["tokenizer"] = BloomTokenizerFast.from_pretrained(model_name)
-            values["client"] = DistributedBloomForCausalLM.from_pretrained(model_name)
+            values["tokenizer"] = AutoTokenizer.from_pretrained(model_name)
+            values["client"] = AutoDistributedModelForCausalLM.from_pretrained(
+                model_name
+            )
             values["huggingface_api_key"] = huggingface_api_key
         except ImportError:
