finalize diff compression

diff-compression
justheuristic 2 years ago
parent 2152036441
commit 984adc5b3f

```diff
@@ -105,14 +105,14 @@ class _ServerInferenceSession:
         # serialize inputs and put them into the queue
         inputs = (new_hidden_states, prompts, hypo_ids)
+        flat_inference_schema = nested_flatten(self.rpc_info["inference_schema"])
+        serialized_inputs = tuple(serialize_torch_tensor(tensor.to(proto.dtype), proto.compression)
+                                  for tensor, proto in zip(inputs, flat_inference_schema))
         outputs_serialized = RemoteExpertWorker.run_coroutine(
             self._step(
                 runtime_pb2.ExpertRequest(
                     uid=self.uid,
-                    tensors=[
-                        serialize_torch_tensor(tensor.to(proto.dtype), proto.compression)
-                        for tensor, proto in zip(inputs, nested_flatten(self.rpc_info["inference_schema"]))
-                    ],
+                    tensors=serialized_inputs,
                     metadata=self._serialized_metadata if not self.stepped else None,
                 )
             )
         )
```
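The inputs are now serialized once into `serialized_inputs` and reused both for the outgoing `ExpertRequest` and for recovering the residuals in the next hunk. The reuse matters because the wire codec is lossy whenever the inference schema specifies a compression type other than `NONE`. A minimal round-trip sketch, assuming hivemind's `serialize_torch_tensor`/`deserialize_torch_tensor` helpers; the toy tensor and the `FLOAT16` codec choice are illustrative assumptions, not from the commit:

```python
import torch
from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
from hivemind.proto import runtime_pb2

# hypothetical hidden states; in the session these come from `inputs`
hidden = torch.randn(1, 4, 8, dtype=torch.float32)

serialized = serialize_torch_tensor(hidden, runtime_pb2.CompressionType.FLOAT16)
roundtrip = deserialize_torch_tensor(serialized)

print(torch.equal(hidden, roundtrip))    # False: fp16 on the wire is lossy
print((hidden - roundtrip).abs().max())  # small but non-zero rounding error
```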
```diff
@@ -120,7 +120,9 @@ class _ServerInferenceSession:
         outputs = list(map(deserialize_torch_tensor, outputs_serialized.tensors))
         assert outputs[0].shape == inputs[0].shape, f"expected outputs[0] to be hidden states but got {outputs[0]}"
         # add back residual connections after rpc_inference
-        return outputs[0].add_(new_hidden_states)
+        inputs_are_compressed = flat_inference_schema[0].compression != runtime_pb2.CompressionType.NONE
+        residuals = deserialize_torch_tensor(serialized_inputs[0]) if inputs_are_compressed else new_hidden_states
+        return outputs[0].add_(residuals)

     async def _step(self, inputs_serialized: runtime_pb2.ExpertRequest) -> runtime_pb2.ExpertResponse:
         """Inference step on serialized data. This code is meant to be run inside RemoteExpertWorker"""
```

```diff
@@ -38,10 +38,10 @@ def test_remote_sequential():
     assert hidden.shape == test_inputs.shape
     assert hidden.requires_grad
     second_half_outputs = second_half(hidden)
-    assert torch.allclose(second_half_outputs, full_outputs)
+    assert torch.allclose(second_half_outputs, full_outputs, rtol=0, atol=1e-4)

     (second_half_outputs * grad_proj).sum().backward()
-    assert torch.allclose(test_inputs.grad, full_grad)
+    assert torch.allclose(test_inputs.grad, full_grad, rtol=0, atol=3e-4)


 @pytest.mark.forked
```
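Since compression makes the remote path no longer bit-exact, the assertions switch to a pure absolute-error check: with `rtol=0`, `torch.allclose` reduces to elementwise `|a - b| <= atol`. The gradient check gets a looser bound (3e-4 vs. 1e-4) because the forward rounding error is amplified again on the way back. A small illustration with made-up numbers:

```python
import torch

a = torch.tensor([1.0, 2.0])
b = a + torch.tensor([5e-5, -8e-5])  # simulated compression error

# rtol=0 turns allclose into a pure absolute-error check: |a - b| <= atol
assert torch.allclose(a, b, rtol=0, atol=1e-4)
assert not torch.allclose(a, b, rtol=0, atol=1e-5)
```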
```diff
@@ -79,11 +79,12 @@ def test_remote_sequential_prompts(batch_size=2, seq_len=5, pre_seq_len=3):
         block = load_pretrained_block(MODEL_NAME, block_index=block_index, torch_dtype=torch.float32)
         (outputs_ref,) = block(outputs_ref)

-    assert torch.allclose(outputs_ref, outputs)
+    outputs_ref = (outputs_ref - torch.cat([inputs, input_prompts_ref], dim=1)) + torch.cat([inputs, input_prompts_ref], dim=1)
+    assert torch.allclose(outputs_ref, outputs)  # exact match

     (outputs_ref * output_proj).sum().backward()
     assert input_prompts_ref.grad is not None
-    assert torch.allclose(input_prompts_ref.grad, input_prompts.grad)
+    assert torch.allclose(input_prompts_ref.grad, input_prompts.grad, rtol=0, atol=1e-5)
     assert intermediate_prompts_ref.grad is not None
-    assert torch.allclose(intermediate_prompts_ref.grad, intermediate_prompts.grad)
+    assert torch.allclose(intermediate_prompts_ref.grad, intermediate_prompts.grad, rtol=0, atol=1e-5)
```
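Note that the reference path now subtracts and re-adds the residual so its rounding matches the remote path, which is why the forward check can stay exact while the prompt gradients get a small absolute tolerance: the activations crossing the swarm are quantized, and that rounding propagates through backward into the gradients. A self-contained sketch of the effect, again with an fp16 round-trip as a stand-in codec; the shapes and the linear map are illustrative assumptions:

```python
import torch

prompt = torch.randn(4, 8, requires_grad=True)
weight = torch.randn(8, 8)

# reference: exact backward pass
(prompt @ weight).sum().backward()
grad_exact = prompt.grad.clone()
prompt.grad = None

# same computation with a lossy round-trip on the activations
compressed = prompt.to(torch.float16).to(torch.float32)
(compressed @ weight).sum().backward()

print((prompt.grad - grad_exact).abs().max())  # non-zero: why the test allows a small atol
```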
