@@ -48,6 +48,9 @@ class TransformerConnectionHandler(ConnectionHandler):
             while request.tensors:  # iterate while user is willing to supply tensors
                 hidden_states = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
 
+                # Cast inputs to backend dtype
+                hidden_states = [tensor.to(requested_backends[0].dtype) for tensor in hidden_states]
+
                 # run request tensors through all requested modules, update caches
                 for backend, cache_handle in zip(requested_backends, cache_handles):
                     cache_metadata[:, 0], cache_metadata[:, 1] = cache_handle, prefix_length
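The first hunk normalizes incoming tensors: a client may serialize hidden states in a different precision (e.g. float32) than the one the server's blocks actually run in, so every deserialized tensor is cast to the dtype of the first requested backend before it is pushed through the modules. A minimal sketch of the effect, with `backend_dtype` standing in for `requested_backends[0].dtype` (the surrounding handler state is not shown in this hunk):

    import torch

    backend_dtype = torch.bfloat16  # e.g. a server hosting its blocks in bfloat16
    client_tensors = [torch.randn(1, 8, 4096)]  # deserialized as float32

    # same cast as the added lines above
    hidden_states = [tensor.to(backend_dtype) for tensor in client_tensors]
    assert all(t.dtype == backend_dtype for t in hidden_states)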
@@ -62,7 +65,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                 # serialize and send last layer outputs
                 yield runtime_pb2.ExpertResponse(
                     tensors=[
-                        serialize_torch_tensor(result, proto.compression, allow_inplace=True)
+                        serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
                         for result, proto in zip(
                             hidden_states, nested_flatten(requested_backends[-1].outputs_schema)
                         )
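The second hunk is the mirror image on the output path: results computed in the backend dtype are cast back to the dtype promised by the last backend's `outputs_schema` before compression, so clients always receive tensors in the advertised precision. A hedged sketch of just the cast-then-zip pattern, with `SchemaEntry` as a stand-in for the real hivemind tensor descriptor:

    from dataclasses import dataclass
    import torch

    @dataclass
    class SchemaEntry:
        dtype: torch.dtype

    outputs_schema = [SchemaEntry(dtype=torch.float32)]
    results = [torch.randn(1, 8, 4096, dtype=torch.bfloat16)]  # backend dtype

    # cast each result to its schema dtype, as in the `+` line above
    to_serialize = [result.to(proto.dtype) for result, proto in zip(results, outputs_schema)]
    assert to_serialize[0].dtype == torch.float32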
@@ -242,7 +245,7 @@ class TransformerConnectionHandler(ConnectionHandler):
             head_dim = backend.module.self_attention.head_dim
 
             cache_descriptor = TensorDescriptor(
-                size=(2, batch_size, MAX_LENGTH, num_heads, head_dim), dtype=torch.float32
+                size=(2, batch_size, MAX_LENGTH, num_heads, head_dim), dtype=backend.dtype
             )
             # [key_or_value, batch_size, max_length, num_heads, head_dim]
 
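The last hunk allocates the attention cache in `backend.dtype` instead of hard-coded `torch.float32`, so cached keys and values match the precision the module actually produces, and 16-bit servers pay half the memory per block. A back-of-the-envelope check (an illustration with made-up BLOOM-like dimensions, not code from this diff):

    import torch

    batch_size, MAX_LENGTH, num_heads, head_dim = 1, 2048, 32, 128
    numel = 2 * batch_size * MAX_LENGTH * num_heads * head_dim  # keys and values

    for dtype in (torch.float32, torch.bfloat16):
        nbytes = numel * torch.finfo(dtype).bits // 8
        print(f"{dtype}: {nbytes / 2**20:.0f} MiB")  # 64 MiB vs 32 MiB per block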