import peft
import pytest
import torch
import transformers
from hivemind import get_logger
from transformers.generation import BeamSearchScorer
from transformers.models.bloom import BloomForCausalLM
from petals import DistributedBloomForCausalLM
from test_utils import *
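# MODEL_NAME, INITIAL_PEERS, REF_NAME, and ADAPTER_NAME used below are expected to come from
# test_utils (star-imported above), presumably populated from the test environment.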
logger = get_logger(__name__)


@pytest.mark.forked
@pytest.mark.parametrize("use_peft", (True, False) if ADAPTER_NAME else (False,))
@pytest.mark.parametrize("pass_empty_tensors", (True, False))
def test_full_model_exact_match(use_peft: bool, pass_empty_tensors: bool, atol_forward=1e-3, atol_inference=1e-3):
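    # Check that the distributed model produces (almost) identical logits in three regimes:
    # a single forward pass, step-by-step inference via an inference session, and (if REF_NAME
    # is set) a local Hugging Face reference model.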
    tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
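    # Load the client-side model in float32 so its logits can be compared against the reference
    # within tight tolerances; INITIAL_PEERS bootstraps the connection to the test swarm, and the
    # PEFT adapter is activated only when use_peft is True.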
    model = DistributedBloomForCausalLM.from_pretrained(
        MODEL_NAME,
        initial_peers=INITIAL_PEERS,
        low_cpu_mem_usage=True,
        torch_dtype=torch.float32,
        active_adapter=ADAPTER_NAME if use_peft else None,
    )
    config = model.config
    assert isinstance(model, DistributedBloomForCausalLM)
    assert len(model.transformer.h) == model.config.num_hidden_layers

    test_inputs = tokenizer("A quick brown fox was minding its own business", return_tensors="pt")["input_ids"]
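    # First, compute logits for the whole prompt in one parallel forward pass.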
    with torch.inference_mode():
        parallel_outputs = model.forward(test_inputs).logits
        assert torch.all(torch.isfinite(parallel_outputs))
        logger.info("Forward outputs are finite")
        embs = model.transformer.word_embeddings(test_inputs)
        embs = model.transformer.word_embeddings_layernorm(embs)
        recurrent_outputs = []
        with model.transformer.h.inference_session(max_length=embs.shape[1]) as sess:
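            # When pass_empty_tensors is enabled, zero-length steps are interleaved with real ones
            # to check that the session tolerates empty inputs.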
            if pass_empty_tensors:
                recurrent_outputs.append(sess.step(torch.empty(1, 0, config.hidden_size)))
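            # Feed the prompt mostly one token at a time, but send positions 4-8 as a single
            # multi-token chunk to exercise inference steps longer than one token.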
            for t in range(embs.shape[1]):
                if t == 4:
                    recurrent_outputs.append(sess.step(embs[:, 4:9, :]))
                elif 4 < t < 9:
                    continue
                else:
                    recurrent_outputs.append(sess.step(embs[:, t : t + 1, :]))
                if t == 2 and pass_empty_tensors:
                    recurrent_outputs.append(sess.step(torch.empty(1, 0, config.hidden_size)))
                    recurrent_outputs.append(sess.step(torch.empty(1, 0, config.hidden_size)))

        recurrent_outputs = torch.cat(recurrent_outputs, dim=1)
        recurrent_outputs = model.transformer.ln_f(recurrent_outputs)
        recurrent_outputs = model.lm_head(recurrent_outputs)
        assert torch.allclose(recurrent_outputs, parallel_outputs, rtol=0, atol=atol_inference)
        logger.info("Inference is consistent with forward")
        del model, embs, recurrent_outputs
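        # If REF_NAME is set, compare the distributed forward pass against a regular local
        # BloomForCausalLM checkpoint; otherwise fail explicitly instead of skipping the check.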
        if REF_NAME:
            ref_model = transformers.BloomForCausalLM.from_pretrained(
                REF_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32
            )
            if use_peft:
                ref_model = peft.PeftModel.from_pretrained(ref_model, ADAPTER_NAME)
                ref_model.train(False)
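            # If the reference vocabulary is larger (e.g. padded), shrink it so the logit tensors
            # have matching shapes for the comparison below.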
            if config.vocab_size < ref_model.config.vocab_size:
                ref_model.resize_token_embeddings(config.vocab_size)
                logger.warning(f"Resized the reference model embeddings, new total = {ref_model.config.vocab_size}")
            dummy_mask = torch.ones_like(test_inputs, dtype=torch.bool)
            # note: this creates a dummy mask to make the test compatible with older transformer versions
            # prior to https://github.com/huggingface/transformers/pull/17837
            ref_outputs = ref_model.forward(test_inputs, attention_mask=dummy_mask).logits.float()
            assert torch.allclose(ref_outputs, parallel_outputs, rtol=0, atol=atol_forward)
            logger.warning(f"Distributed forward is consistent with {type(ref_model)}.forward")
            del ref_model, ref_outputs, dummy_mask
        else:
            logger.warning("Did not test exact match with local model: REF_NAME environment variable is not set")
            assert False


@pytest.mark.forked
def test_greedy_generation(max_new_tokens=4):
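    # Greedy decoding through the distributed model should exactly match HF's greedy_search loop
    # run over the very same model, both for a single prompt and for a padded batch.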
    tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
    model = DistributedBloomForCausalLM.from_pretrained(
        MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
    )
    inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
    remote_outputs = model.generate(
        inputs,
        max_new_tokens=max_new_tokens,
    )
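    # Call HF's greedy_search as an unbound method on the same model, so the reference decoding
    # loop runs over identical logits.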
    hf_outputs = BloomForCausalLM.greedy_search(model, input_ids=inputs, max_length=inputs.size(1) + max_new_tokens)
    assert torch.allclose(remote_outputs, hf_outputs), "Greedy search results are not identical to HF"
    inputs_batch = tokenizer(["A cat sat on a mat", "A dog sat on a mat"], return_tensors="pt", padding=True)[
        "input_ids"
    ]
    remote_outputs_batch = model.generate(
        inputs_batch,
        max_new_tokens=max_new_tokens,
    )
    hf_outputs_batch = BloomForCausalLM.greedy_search(
        model, input_ids=inputs_batch, max_length=inputs_batch.size(1) + max_new_tokens
    )
    assert torch.allclose(
        remote_outputs_batch, hf_outputs_batch
    ), "Greedy search results are not identical to HF in multibatch mode"


@pytest.mark.forked
@pytest.mark.parametrize("sampling_options", [dict(), dict(temperature=100.0), dict(top_k=5), dict(top_p=0.9)])
@pytest.mark.skip("Sampling is currently not consistent with outputs from Transformers")
def test_sampling(sampling_options, max_new_tokens=4):
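    # Sampling is stochastic, so both the remote generate() call and HF's sample() run under
    # torch.random.fork_rng() starting from the same RNG state and should draw identical tokens.
    # The test is currently skipped (see the marker above) because outputs do not yet match.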
    torch.manual_seed(0)
    tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
    model = DistributedBloomForCausalLM.from_pretrained(
        MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
    )
    logits_warper = BloomForCausalLM._get_logits_warper(model, num_beams=1, **sampling_options)
    inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
    with torch.random.fork_rng():
        remote_outputs = model.generate(
            inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            **sampling_options,
        )
    with torch.random.fork_rng():
        hf_outputs = BloomForCausalLM.sample(
            model, input_ids=inputs, max_length=inputs.size(1) + max_new_tokens, logits_warper=logits_warper
        )
    assert torch.allclose(remote_outputs, hf_outputs), "Sampling results are not identical to HF"

    inputs_batch = tokenizer(["A cat sat on a mat", "A dog sat on a mat"], return_tensors="pt", padding=True)[
        "input_ids"
    ]
    with torch.random.fork_rng():
        remote_outputs_batch = model.generate(
            inputs_batch,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            **sampling_options,
        )
    with torch.random.fork_rng():
        hf_outputs_batch = BloomForCausalLM.sample(
            model,
            input_ids=inputs_batch,
            max_length=inputs_batch.size(1) + max_new_tokens,
            logits_warper=logits_warper,
        )
    assert torch.allclose(
        remote_outputs_batch, hf_outputs_batch
    ), "Sampling results are not identical to HF in multibatch mode"


@pytest.mark.forked
def test_beam_search_generation(max_new_tokens=4, num_beams=2):
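    # Beam search through the distributed model is checked against HF's beam_search driven by an
    # equivalent BeamSearchScorer; the prompt is duplicated num_beams times because the low-level
    # beam_search API expects input_ids already expanded to one row per beam.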
    tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
    model = DistributedBloomForCausalLM.from_pretrained(
        MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
    )
    text = "A cat sat on a mat"
    inputs = tokenizer(text, return_tensors="pt")["input_ids"]
    remote_outputs = model.generate(
        inputs,
        max_new_tokens=max_new_tokens,
        num_beams=num_beams,
    )
    beam_scorer = BeamSearchScorer(
        batch_size=inputs.size(0),
        num_beams=num_beams,
        device=inputs.device,
        length_penalty=0,
        do_early_stopping=False,
    )
    hf_inputs = tokenizer([text] * 2, return_tensors="pt")["input_ids"]
    hf_outputs = BloomForCausalLM.beam_search(
        model, input_ids=hf_inputs, max_length=inputs.size(1) + max_new_tokens, beam_scorer=beam_scorer
    )
    assert torch.allclose(remote_outputs, hf_outputs), "Beam search results are not identical to HF"