Fixes and test

pull/206/head
authored by artek0chumak 1 year ago, committed by Artem Chumachenko
parent a22ecc524d
commit d14debea35

@@ -107,9 +107,6 @@ class _ServerInferenceSession:
         if attention_mask is None:
             attention_mask = DUMMY
-        if attention_mask is None:
-            attention_mask = DUMMY
         # serialize inputs and put them into the queue
         inputs = (new_hidden_states, attention_mask, prompts, hypo_ids)
         outputs_serialized = RemoteExpertWorker.run_coroutine(
@@ -239,9 +236,6 @@ class InferenceSession:
         if attention_mask is None:
             attention_mask = DUMMY
-        if attention_mask is None:
-            attention_mask = DUMMY
         inputs_device = inputs.device
         inputs_dtype = inputs.dtype
         inputs = inputs.cpu()
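The first two hunks drop an accidentally duplicated default-mask block on the client side. The surviving pattern substitutes a placeholder when the caller passes no mask, so the wire format stays fixed. A minimal sketch of that sentinel idea, assuming DUMMY is a zero-element tensor as in hivemind (names here are illustrative, not the project's exact definitions):

    import torch

    DUMMY = torch.empty(0)  # assumption: an empty tensor marks "no mask supplied"

    def is_dummy(tensor: torch.Tensor) -> bool:
        # A placeholder carries no payload, so numel() == 0 identifies it
        # regardless of dtype or shape.
        return tensor.numel() == 0

    def fill_missing_mask(attention_mask):
        # Mirror of the surviving branch above: substitute the sentinel once.
        if attention_mask is None:
            attention_mask = DUMMY
        return attention_mask

    assert is_dummy(fill_missing_mask(None))
    assert not is_dummy(fill_missing_mask(torch.ones(1, 4)))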

@@ -192,9 +192,6 @@ class DistributedBloomModel(_LowCPUMemoryMixin, BloomModel):
         hidden_states = self.word_embeddings_layernorm(inputs_embeds)
         output_shape = input_shape + (hidden_states.size(-1),)
         if attention_mask is None:
             attention_mask = torch.ones((batch_size, hidden_states.size(1)), device=hidden_states.device)
-        if attention_mask is None:
-            attention_mask = torch.ones((batch_size, hidden_states.size(1)), device=hidden_states.device)
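On the model side the same duplicate is removed; the one remaining branch builds an all-ones mask (attend to every position) on the activations' device. A small sketch of that default, with illustrative names:

    import torch

    def default_attention_mask(hidden_states: torch.Tensor) -> torch.Tensor:
        batch_size, seq_len, _ = hidden_states.shape
        # All ones means "no token is padding"; allocating on the activations'
        # device avoids a host-to-device copy later.
        return torch.ones((batch_size, seq_len), device=hidden_states.device)

    hidden = torch.zeros(2, 5, 16)
    assert default_attention_mask(hidden).shape == (2, 5)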

@@ -333,7 +333,7 @@ class _SequenceManagerUpdateThread(threading.Thread):
 def maybe_log_traceback(exc: Exception):
     traceback_level = logging.DEBUG if str(exc) or isinstance(exc, asyncio.TimeoutError) else logging.WARNING
-    logger.log(logging.INFO, "See detailed traceback below:", exc_info=True)
+    logger.log(traceback_level, "See detailed traceback below:", exc_info=True)


 class MissingBlocksError(RuntimeError):
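This hunk fixes a real bug: traceback_level was computed one line earlier but then ignored, so every traceback went out at INFO. A self-contained sketch of the corrected behavior (logger name and demo exception are illustrative):

    import asyncio
    import logging

    logging.basicConfig(level=logging.DEBUG)
    logger = logging.getLogger("demo")

    def maybe_log_traceback(exc: Exception) -> None:
        # Routine failures (a timeout, or any exception with a message) go to
        # DEBUG; messageless exceptions are more suspicious and go to WARNING.
        traceback_level = logging.DEBUG if str(exc) or isinstance(exc, asyncio.TimeoutError) else logging.WARNING
        logger.log(traceback_level, "See detailed traceback below:", exc_info=True)

    try:
        raise asyncio.TimeoutError("handshake timed out")
    except Exception as e:
        maybe_log_traceback(e)  # now logged at DEBUG, not hard-coded INFO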

@@ -84,7 +84,7 @@ class TransformerBackend(ModuleBackend):
     def inference_step(
         self,
         hidden_states: torch.Tensor,
-        attention_masks: torch.Tensor,
+        attention_mask: torch.Tensor,
         hypo_ids: torch.LongTensor,
         inference_info: InferenceMetadata,
     ) -> Tuple[torch.Tensor, ...]:

@@ -13,7 +13,8 @@ logger = get_logger(__file__)
 @pytest.mark.forked
 @pytest.mark.parametrize("pass_empty_tensors", (True, False))
-def test_full_model_exact_match(pass_empty_tensors: bool, atol_forward=1e-3, atol_inference=1e-3):
+@pytest.mark.parametrize("second_token_attention_mask", (1, 0))
+def test_full_model_exact_match(pass_empty_tensors: bool, second_token_attention_mask: int, atol_forward=1e-3, atol_inference=1e-3):
     tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
     model = DistributedBloomForCausalLM.from_pretrained(
         MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
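Stacked parametrize decorators multiply: pytest runs the Cartesian product of both parameter sets, so the test above now executes four times. A minimal standalone illustration:

    import pytest

    @pytest.mark.parametrize("pass_empty_tensors", (True, False))
    @pytest.mark.parametrize("second_token_attention_mask", (1, 0))
    def test_matrix(pass_empty_tensors: bool, second_token_attention_mask: int):
        # Runs as (True, 1), (True, 0), (False, 1), (False, 0).
        assert second_token_attention_mask in (0, 1)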
@@ -23,9 +24,11 @@ def test_full_model_exact_match(pass_empty_tensors: bool, atol_forward=1e-3, atol_inference=1e-3):
     assert len(model.transformer.h) == model.config.n_layer

     test_inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
+    attention_mask = torch.ones_like(test_inputs)
+    attention_mask[0, 1] = second_token_attention_mask

     with torch.inference_mode():
-        parallel_outputs = model.forward(test_inputs).logits
+        parallel_outputs = model.forward(test_inputs, attention_mask=attention_mask).logits
         assert torch.all(torch.isfinite(parallel_outputs))
         logger.info("Forward outputs are finite")
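The mask built here is all ones except, in half of the parametrized runs, position 1; that single knob covers both the trivial mask and a genuinely masked token. A sketch with stand-in token ids:

    import torch

    test_inputs = torch.tensor([[5, 17, 9, 3]])    # stand-in token ids
    attention_mask = torch.ones_like(test_inputs)  # 1 = attend, 0 = ignore
    attention_mask[0, 1] = 0                       # the second_token_attention_mask == 0 case
    assert attention_mask.tolist() == [[1, 0, 1, 1]]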
@@ -37,7 +40,7 @@ def test_full_model_exact_match(pass_empty_tensors: bool, atol_forward=1e-3, atol_inference=1e-3):
             recurrent_outputs.append(sess.step(torch.empty(1, 0, config.hidden_size)))

         for t in range(embs.shape[1]):
-            recurrent_outputs.append(sess.step(embs[:, t : t + 1, :]))
+            recurrent_outputs.append(sess.step(embs[:, t : t + 1, :], attention_mask=attention_mask[:, : t + 1]))
             if t == int(embs.shape[1] // 2) and pass_empty_tensors:
                 recurrent_outputs.append(sess.step(torch.empty(1, 0, config.hidden_size)))
                 recurrent_outputs.append(sess.step(torch.empty(1, 0, config.hidden_size)))
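The step call slices the mask because inference is incremental: after processing token t, the session has seen t + 1 positions, and the mask it receives must cover exactly those. A shape-only sketch:

    import torch

    attention_mask = torch.tensor([[1, 0, 1, 1]])
    for t in range(attention_mask.shape[1]):
        step_mask = attention_mask[:, : t + 1]
        assert step_mask.shape[1] == t + 1  # grows by one position per step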
@@ -58,13 +61,10 @@ def test_full_model_exact_match(pass_empty_tensors: bool, atol_forward=1e-3, atol_inference=1e-3):
         ref_model.resize_token_embeddings(config.vocab_size)
         logger.warning(f"Resized the reference model embeddings, new total = {ref_model.config.vocab_size}")

-        dummy_mask = torch.ones_like(test_inputs, dtype=torch.bool)
-        # note: this creates a dummy mask to make the test compatible with older transformer versions
-        # prior to https://github.com/huggingface/transformers/pull/17837
-        ref_outputs = ref_model.forward(test_inputs, attention_mask=dummy_mask).logits.float()
+        ref_outputs = ref_model.forward(test_inputs, attention_mask=attention_mask).logits.float()

         assert torch.allclose(ref_outputs, parallel_outputs, rtol=0, atol=atol_forward)
         logger.warning(f"Distributed forward is consistent with {type(ref_model)}.forward")
-        del ref_model, ref_outputs, dummy_mask
+        del ref_model, ref_outputs
     else:
         logger.warning("Did not test exact match with local model: REF_NAME environment variable is not set")
         assert False
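With the mask now parametrized, the reference model must consume the same mask as the distributed one, or the exact-match check would diverge whenever a token is masked; the hard-coded dummy_mask (a workaround for transformers versions before PR 17837) is therefore gone. For reference, the comparison contract itself: rtol=0 turns atol_forward into a pure per-element absolute bound.

    import torch

    a = torch.tensor([1.0000, 2.0000])
    b = torch.tensor([1.0005, 1.9996])
    # With rtol=0, allclose checks |a - b| <= atol element-wise.
    assert torch.allclose(a, b, rtol=0, atol=1e-3)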

@@ -71,7 +71,7 @@ class DummyCustomSequenceManager(RemoteSequenceManager):
         rpc_info = super().rpc_info
         dims = (2048, 1024)
         compressed_input_schema = BatchTensorDescriptor(dims, compression=runtime_pb2.CompressionType.FLOAT16)
-        rpc_info["forward_schema"] = (compressed_input_schema,), dict()  # (args, kwargs)
+        rpc_info["forward_schema"] = (compressed_input_schema, compressed_input_schema), dict()  # (args, kwargs)
         return rpc_info

     def get_request_metadata(self, protocol: str, *args, **kwargs):
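Finally, the compression test's schema grows to two descriptors because the forward RPC now carries two positional tensors, hidden states and the attention mask, and needs one descriptor per tensor. A hedged stand-in (not hivemind's real BatchTensorDescriptor API) showing the shape of that pair:

    from dataclasses import dataclass

    @dataclass
    class FakeDescriptor:  # stand-in for hivemind's BatchTensorDescriptor
        dims: tuple
        compression: str = "FLOAT16"

    desc = FakeDescriptor((2048, 1024))
    forward_schema = (desc, desc), dict()  # one descriptor per positional tensor
    assert len(forward_schema[0]) == 2     # hidden_states + attention_mask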
