|
|
|
@@ -20,17 +20,26 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|
|
|
|
async def rpc_inference(
    self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
) -> AsyncIterator[runtime_pb2.ExpertRequest]:
    """Compute a single step of inference using attention cache; update attention cache accordingly."""
    try:
        # Take the first request from the stream; its uid selects which backend serves it.
        request = await anext(requests)
        backend = self.module_backends[request.uid]
        assert isinstance(backend, TransformerBackend)

        # prepare attention cache
        hidden_size = backend.module.hidden_size
        # Metadata row passed alongside the inputs: [cache_handle, current_sequence_length].
        # Initialized to -1 and filled in once the cache slot is actually allocated below.
        cache_metadata = torch.tensor([[-1, -1]], dtype=torch.int64)  # [cache_handle, current_sequence_length]
        cache_descriptor = TensorDescriptor(size=(1, MAX_LENGTH, hidden_size), dtype=torch.float32)
        current_sequence_length = 0

        # Hold the cache slot for the duration of this inference session; the context
        # manager releases it even if the generator is closed or an error is raised.
        async with backend.memory_cache.allocate_cache(cache_descriptor) as cache_handle:
            inputs = [cache_metadata, *(deserialize_torch_tensor(tensor) for tensor in request.tensors)]
            print("INPUTS:", inputs)  # NOTE(review): debug output — consider the project logger instead
            # assumes inputs[1] is hidden states shaped (batch, seq, hidden) — TODO confirm with callers
            assert len(inputs) == 2 and inputs[1].ndim == 3, "send only hidden states for now"
            # Point the runtime at the allocated cache slot and the current write offset.
            cache_metadata[0, 0], cache_metadata[0, 1] = cache_handle, current_sequence_length
            outputs = await self._process_inputs(inputs, backend.inference_pool, backend.outputs_schema)
            yield runtime_pb2.ExpertResponse(tensors=outputs)

            # Advance the cache offset by the number of tokens just processed (dim 1 of hidden states).
            current_sequence_length += inputs[1].shape[1]
    finally:
        print("CLOSED RPC_INFERENCE")  # NOTE(review): debug output — consider the project logger instead
|