rpc_inference works

inference_chain
justheuristic 2 years ago
parent 47c308306a
commit 33358bc52b

@ -18,7 +18,6 @@ class TransformerBackend(ModuleBackend):
super().__init__(*args, **kwargs)
assert isinstance(self.module, BloomBlock)
self.memory_cache = memory_cache
for name, param in self.module.named_parameters():
assert not param.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {name} does"
for name, buf in self.module.named_buffers():
@ -28,11 +27,29 @@ class TransformerBackend(ModuleBackend):
def inference_step(self, cache_metadata: torch.IntTensor, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
attention_cache_handle = int(cache_metadata[0, 0].item())
current_sequence_length = int(cache_metadata[0, 1].item())
prefix_length = int(cache_metadata[0, 1].item())
hidden_states, *_ = inputs
assert hidden_states.ndim == 3, "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
with self.memory_cache.use_cache(attention_cache_handle) as cache:
print('METADATA:', cache_metadata, "CACHE", cache.mean(), "CACHE ENTRIES:", len(self.memory_cache._allocated_tensors))
cache[...] += 1
return (inputs[0] + cache.flatten()[0],)
print('METADATA:', cache_metadata)
assert isinstance(self.module, BloomBlock) and cache.shape[0] == 2 and cache.ndim == 5
layer_past = past_k, past_v = cache[0, :, :prefix_length], cache[1, :, :prefix_length]
print(past_k.shape, past_v.shape)
hidden_states, (new_k, new_v) = self.module.forward(hidden_states, layer_past=layer_past, use_cache=True)
# todo remove these debugprints
new_length = new_v.shape[1]
assert new_length > prefix_length
assert new_k.shape[0] == past_k.shape[0] and new_v.shape[0] == past_v.shape[0]
assert new_k.shape[1] == new_length and new_v.shape[1] == new_length
assert new_k.shape[2:] == past_k.shape[2:] and new_v.shape[2:] == past_v.shape[2:]
assert torch.allclose(new_v[:, :past_v.shape[1]], past_v)
assert torch.allclose(new_k[:, :past_k.shape[1]], past_k)
cache[0, :, prefix_length: new_length, :] = new_k[:, prefix_length : new_length]
cache[1, :, prefix_length: new_length, :] = new_v[:, prefix_length: new_length]
return (hidden_states,)
def get_pools(self) -> Sequence[TaskPool]:
return self.forward_pool, self.backward_pool, self.inference_pool

@ -30,21 +30,22 @@ class TransformerConnectionHandler(ConnectionHandler):
assert isinstance(backend, TransformerBackend)
# prepare attention cache
hidden_size = backend.module.hidden_size
cache_metadata = torch.tensor([[-1, -1]], dtype=torch.int64) # [cache_handle, current_sequence_length]
cache_descriptor = TensorDescriptor(size=(1, MAX_LENGTH, hidden_size), dtype=torch.float32)
current_sequence_length = 0
num_heads = backend.module.self_attention.num_heads
head_dim = backend.module.self_attention.head_dim
cache_metadata = torch.tensor([[-1, -1]], dtype=torch.int64) # [cache_handle, prefix_length]
cache_descriptor = TensorDescriptor(size=(2, 1, MAX_LENGTH, num_heads, head_dim), dtype=torch.float32)
prefix_length = 0
async with backend.memory_cache.allocate_cache(cache_descriptor) as cache_handle:
while request.uid or request.tensors: # iterate while user is willing to supply tensors
inputs = [cache_metadata, *(deserialize_torch_tensor(tensor) for tensor in request.tensors)]
print("INPUTS:", inputs)
assert len(inputs) == 2 and inputs[1].ndim == 3, "send only hidden states for now"
cache_metadata[0, 0], cache_metadata[0, 1] = cache_handle, current_sequence_length
cache_metadata[0, 0], cache_metadata[0, 1] = cache_handle, prefix_length
outputs = await self._process_inputs(inputs, backend.inference_pool, backend.outputs_schema)
yield runtime_pb2.ExpertResponse(tensors=outputs)
current_sequence_length += inputs[1].shape[1]
prefix_length += inputs[1].shape[1]
request = await(anext(requests))
finally:
print("CLOSED RPC_INFERENCE")

Loading…
Cancel
Save