@@ -3,29 +3,31 @@ import pytest
 import torch
 import transformers
 from hivemind import get_logger
-from transformers.generation import BeamSearchScorer
-from transformers.models.bloom import BloomForCausalLM
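+# GenerationMixin is aliased so the tests below can call HF's reference decoding loops
+# (greedy_search, sample, beam_search) directly against the distributed model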
+from transformers.generation import BeamSearchScorer, GenerationMixin as HfGenerationMixin

-from petals import DistributedBloomForCausalLM
+from petals import AutoDistributedModelForCausalLM
 from test_utils import *

 logger = get_logger(__name__)


+@pytest.fixture
+def tokenizer():
+    # We set use_fast=False since LlamaTokenizerFast is slow on load
+    return transformers.AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)
+
+
 @pytest.mark.forked
 @pytest.mark.parametrize("use_peft", (True, False) if ADAPTER_NAME else (False,))
 @pytest.mark.parametrize("pass_empty_tensors", (True, False))
-def test_full_model_exact_match(use_peft: bool, pass_empty_tensors: bool, atol_forward=1e-3, atol_inference=1e-3):
-    tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
-    model = DistributedBloomForCausalLM.from_pretrained(
+def test_full_model_exact_match(tokenizer, use_peft, pass_empty_tensors, atol_forward=1e-3, atol_inference=1e-3):
+    model = AutoDistributedModelForCausalLM.from_pretrained(
         MODEL_NAME,
         initial_peers=INITIAL_PEERS,
         low_cpu_mem_usage=True,
         torch_dtype=torch.float32,
         active_adapter=ADAPTER_NAME if use_peft else None,
     )
     config = model.config
-    assert isinstance(model, DistributedBloomForCausalLM)
     assert len(model.transformer.h) == model.config.num_hidden_layers

     test_inputs = tokenizer("A quick brown fox was minding its own business", return_tensors="pt")["input_ids"]
@@ -63,7 +65,7 @@ def test_full_model_exact_match(use_peft: bool, pass_empty_tensors: bool, atol_f
         del model, embs, recurrent_outputs

     if REF_NAME:
-        ref_model = transformers.BloomForCausalLM.from_pretrained(
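+        # AutoModelForCausalLM resolves the concrete architecture (Bloom, Llama, ...) from the checkpoint config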
+        ref_model = transformers.AutoModelForCausalLM.from_pretrained(
             REF_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32
         )
         if use_peft:
@@ -86,27 +88,29 @@ def test_full_model_exact_match(use_peft: bool, pass_empty_tensors: bool, atol_f

 @pytest.mark.forked
-def test_greedy_generation(max_new_tokens=4):
-    tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
-    model = DistributedBloomForCausalLM.from_pretrained(
-        MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
+def test_greedy_generation(tokenizer, max_new_tokens=4):
+    model = AutoDistributedModelForCausalLM.from_pretrained(
+        MODEL_NAME, initial_peers=INITIAL_PEERS, torch_dtype=torch.float32
     )
     inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
     remote_outputs = model.generate(
         inputs,
         max_new_tokens=max_new_tokens,
     )
-    hf_outputs = BloomForCausalLM.greedy_search(model, input_ids=inputs, max_length=inputs.size(1) + max_new_tokens)
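+    # The mixin method is called unbound, so HF's stock greedy loop runs with the remote model as `self`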
+    hf_outputs = HfGenerationMixin.greedy_search(model, input_ids=inputs, max_length=inputs.size(1) + max_new_tokens)
     assert torch.allclose(remote_outputs, hf_outputs), "Greedy search results are not identical to HF"

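+    # Llama-style tokenizers ship without a pad token; fall back to EOS so the batch below can be padded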
+    if tokenizer.pad_token_id is None:
+        tokenizer.pad_token_id = tokenizer.eos_token_id
     inputs_batch = tokenizer(["A cat sat on a mat", "A dog sat on a mat"], return_tensors="pt", padding=True)[
         "input_ids"
     ]
     remote_outputs_batch = model.generate(
         inputs_batch,
         max_new_tokens=max_new_tokens,
     )
-    hf_outputs_batch = BloomForCausalLM.greedy_search(
+    hf_outputs_batch = HfGenerationMixin.greedy_search(
         model, input_ids=inputs_batch, max_length=inputs_batch.size(1) + max_new_tokens
     )
     assert torch.allclose(
@@ -117,13 +121,13 @@ def test_greedy_generation(max_new_tokens=4):

 @pytest.mark.forked
 @pytest.mark.parametrize("sampling_options", [dict(), dict(temperature=100.0), dict(top_k=5), dict(top_p=0.9)])
 @pytest.mark.skip("Sampling is currently not consistent with outputs from Transformers")
-def test_sampling(sampling_options, max_new_tokens=4):
+def test_sampling(tokenizer, sampling_options, max_new_tokens=4):
     torch.manual_seed(0)
-    tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
-    model = DistributedBloomForCausalLM.from_pretrained(
-        MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
+    model = AutoDistributedModelForCausalLM.from_pretrained(
+        MODEL_NAME, initial_peers=INITIAL_PEERS, torch_dtype=torch.float32
     )
-    logits_warper = BloomForCausalLM._get_logits_warper(model, num_beams=1, **sampling_options)
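+    # _get_logits_warper builds HF's temperature/top-k/top-p warpers from the parametrized sampling options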
+    logits_warper = HfGenerationMixin._get_logits_warper(model, num_beams=1, **sampling_options)
     inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
     with torch.random.fork_rng():
         remote_outputs = model.generate(
@@ -133,7 +137,7 @@ def test_sampling(sampling_options, max_new_tokens=4):
             **sampling_options,
         )
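+    # The earlier fork_rng() restored the global RNG state, so HF's sampler below draws the same randomness as the remote call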
     with torch.random.fork_rng():
-        hf_outputs = BloomForCausalLM.sample(
+        hf_outputs = HfGenerationMixin.sample(
             model, input_ids=inputs, max_length=inputs.size(1) + max_new_tokens, logits_warper=logits_warper
         )
     assert torch.allclose(remote_outputs, hf_outputs), "Sampling results are not identical to HF"
@@ -149,7 +153,7 @@ def test_sampling(sampling_options, max_new_tokens=4):
             **sampling_options,
         )
     with torch.random.fork_rng():
-        hf_outputs_batch = BloomForCausalLM.sample(
+        hf_outputs_batch = HfGenerationMixin.sample(
             model,
             input_ids=inputs_batch,
             max_length=inputs_batch.size(1) + max_new_tokens,
@@ -161,10 +165,9 @@ def test_sampling(sampling_options, max_new_tokens=4):

 @pytest.mark.forked
-def test_beam_search_generation(max_new_tokens=4, num_beams=2):
-    tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
-    model = DistributedBloomForCausalLM.from_pretrained(
-        MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
+def test_beam_search_generation(tokenizer, max_new_tokens=4, num_beams=2):
+    model = AutoDistributedModelForCausalLM.from_pretrained(
+        MODEL_NAME, initial_peers=INITIAL_PEERS, torch_dtype=torch.float32
     )
     text = "A cat sat on a mat"
     inputs = tokenizer(text, return_tensors="pt")["input_ids"]
@@ -181,7 +184,7 @@ def test_beam_search_generation(max_new_tokens=4, num_beams=2):
         do_early_stopping=False,
     )
     hf_inputs = tokenizer([text] * 2, return_tensors="pt")["input_ids"]
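+    # HF's beam_search expects input_ids already repeated num_beams times, hence [text] * 2 above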
-    hf_outputs = BloomForCausalLM.beam_search(
+    hf_outputs = HfGenerationMixin.beam_search(
         model, input_ids=hf_inputs, max_length=inputs.size(1) + max_new_tokens, beam_scorer=beam_scorer
     )
     assert torch.allclose(remote_outputs, hf_outputs), "Beam search results are not identical to HF"