|
|
|
@ -11,14 +11,10 @@ use_hivemind_log_handler("in_root_logger")
|
|
|
|
|
logger = get_logger(__file__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture
def tokenizer():
    """Pytest fixture: a Bloom fast tokenizer for the model under test.

    Loads ``BloomTokenizerFast`` weights/vocab from ``MODEL_NAME`` (a
    module-level constant defined elsewhere in this file).
    """
    bloom_tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
    return bloom_tokenizer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Pytest fixture: builds the distributed Bloom causal-LM under test.
# NOTE(review): the argument list of ``from_pretrained(`` below is truncated
# in this extraction (the call is never closed here) — recover the full
# arguments from the original file; presumably they mirror the call inside
# test_full_model_exact_match (MODEL_NAME, initial_peers=INITIAL_PEERS,
# low_cpu_mem_usage=True, torch_dtype=torch.float32) — TODO confirm.
@pytest.fixture
|
|
|
|
|
def model():
|
|
|
|
|
return DistributedBloomForCausalLM.from_pretrained(
|
|
|
|
|
@pytest.mark.forked
|
|
|
|
|
def test_full_model_exact_match(atol_forward=1e-3, atol_inference=1e-3):
|
|
|
|
|
tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
|
|
|
|
|
model = DistributedBloomForCausalLM.from_pretrained(
|
|
|
|
|
MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
|
|
|
|
|
)
|
|
|
|
|
config = model.config
|
|
|
|
|