basic cache checks (via debugprint)

Branch: inference_chain
Author: justheuristic, 2 years ago
Parent: fee63bd440
Commit: a44cb84f06

@@ -26,14 +26,14 @@ class TransformerBackend(ModuleBackend):
         self.inference_pool = TaskPool(self.inference_step, max_batch_size=1, name=f"{self.name}_inference")
 
-    def inference_step(self, attention_cache_handle: torch.IntTensor, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
-        attention_cache_handle = int(attention_cache_handle.item())
-        print('HANDLE:', attention_cache_handle)
+    def inference_step(self, cache_metadata: torch.IntTensor, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
+        attention_cache_handle = int(cache_metadata[0, 0].item())
+        current_sequence_length = int(cache_metadata[0, 1].item())
         with self.memory_cache.use_cache(attention_cache_handle) as cache:
+            print('METADATA:', cache_metadata, "CACHE ENTRIES:", len(self.memory_cache._allocated_tensors))
             print(inputs[0].shape, cache.shape)
             cache[...] += 1
-            return (inputs[0] + cache,)
+            return (inputs[0] + cache.flatten()[0],)
 
     def get_pools(self) -> Sequence[TaskPool]:
         return self.forward_pool, self.backward_pool, self.inference_pool
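
For reference, a standalone sketch (not part of the commit) of the metadata convention the new signature relies on: a 1x2 int64 tensor carries [cache_handle, current_sequence_length], and the cache[...] += 1 body is only a debug check that the cached tensor persists between calls. All shapes and values below are made up for illustration.

import torch

# Pack the two scalars the way the handler does, then unpack them the way
# inference_step does; both sides read the same 1x2 int64 tensor.
cache_metadata = torch.tensor([[-1, -1]], dtype=torch.int64)
cache_handle, current_sequence_length = 42, 0  # example values only
cache_metadata[0, 0], cache_metadata[0, 1] = cache_handle, current_sequence_length

unpacked_handle = int(cache_metadata[0, 0].item())
unpacked_length = int(cache_metadata[0, 1].item())
assert (unpacked_handle, unpacked_length) == (cache_handle, current_sequence_length)

# The debug body increments the whole cache on every call, so repeated calls
# with the same input produce outputs that grow by 1 per call -- the behaviour
# the print statements are meant to make visible.
cache = torch.zeros(1, 8, 4)          # stand-in for the allocated cache tensor
hidden_states = torch.randn(1, 3, 4)  # stand-in for inputs[0]
for expected_offset in (1.0, 2.0, 3.0):
    cache[...] += 1
    output = hidden_states + cache.flatten()[0]
    assert torch.allclose(output, hidden_states + expected_offset)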

@@ -20,17 +20,26 @@ class TransformerConnectionHandler(ConnectionHandler):
     async def rpc_inference(
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
     ) -> AsyncIterator[runtime_pb2.ExpertRequest]:
-        request = await anext(requests)
-        backend = self.module_backends[request.uid]
-        assert isinstance(backend, TransformerBackend)
-        inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
-        hidden_size = backend.module.hidden_size
-        cache_descriptor = TensorDescriptor(size=(1, MAX_LENGTH, hidden_size), dtype=torch.float32)
-        async with backend.memory_cache.allocate_cache(cache_descriptor) as handle:
-            inputs.insert(0, torch.tensor([handle], dtype=torch.int64))
-            outputs = await self._process_inputs(inputs, backend.inference_pool, backend.outputs_schema)
-            yield runtime_pb2.ExpertResponse(tensors=outputs)
+        """Compute a single step of inference using attention cache; update attention cache accordingly."""
+        try:
+            request = await anext(requests)
+            backend = self.module_backends[request.uid]
+            assert isinstance(backend, TransformerBackend)
+
+            # prepare attention cache
+            hidden_size = backend.module.hidden_size
+            cache_metadata = torch.tensor([[-1, -1]], dtype=torch.int64)  # [cache_handle, current_sequence_length]
+            cache_descriptor = TensorDescriptor(size=(1, MAX_LENGTH, hidden_size), dtype=torch.float32)
+            current_sequence_length = 0
+
+            async with backend.memory_cache.allocate_cache(cache_descriptor) as cache_handle:
+                inputs = [cache_metadata, *(deserialize_torch_tensor(tensor) for tensor in request.tensors)]
+                print("INPUTS:", inputs)
+                assert len(inputs) == 2 and inputs[1].ndim == 3, "send only hidden states for now"
+                cache_metadata[0, 0], cache_metadata[0, 1] = cache_handle, current_sequence_length
+                outputs = await self._process_inputs(inputs, backend.inference_pool, backend.outputs_schema)
+                yield runtime_pb2.ExpertResponse(tensors=outputs)
+                current_sequence_length += inputs[1].shape[1]
+        finally:
+            print("CLOSED RPC_INFERENCE")
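
The handler above still reads only one request, but the trailing current_sequence_length += inputs[1].shape[1] points at the intended multi-step flow: the cache is reserved once for MAX_LENGTH positions, and the sequence offset advances with every new slice of hidden states. A rough sketch of that loop, with MAX_LENGTH and hidden_size as assumed placeholder values rather than anything taken from the commit:

import torch

MAX_LENGTH = 2048   # assumed cap matching the TensorDescriptor above
hidden_size = 4096  # assumed model dimension

current_sequence_length = 0
steps = [torch.randn(1, 1, hidden_size) for _ in range(3)]  # one token per step
for hidden_states in steps:
    assert hidden_states.ndim == 3, "send only hidden states for now"
    # every step must fit inside the cache reserved by allocate_cache
    assert current_sequence_length + hidden_states.shape[1] <= MAX_LENGTH
    # ... here the real handler would run backend.inference_pool on
    # [cache_metadata, hidden_states] and yield an ExpertResponse ...
    current_sequence_length += hidden_states.shape[1]

assert current_sequence_length == 3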