support various dtype for inference

pull/39/head
dbaranchuk 2 years ago
parent b8a78b8254
commit 6165aca545

@@ -70,7 +70,7 @@ class RemoteTransformerBlockInferenceSession:
             runtime_pb2.ExpertRequest(
                 uid=self.uid,
                 tensors=[
-                    serialize_torch_tensor(tensor, proto.compression)
+                    serialize_torch_tensor(tensor.to(proto.dtype), proto.compression)
                     for tensor, proto in zip(inputs, nested_flatten(self.rpc_info["forward_schema"]))
                 ],
             )

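For reference, a minimal standalone sketch of the client-side pattern: each input tensor is cast to the dtype announced in the server's forward schema before it is serialized, so a half-precision server receives payloads in its own dtype. The `TensorSpec` class and `cast_inputs_to_schema` helper below are hypothetical stand-ins for the schema entries and the inline list comprehension in the diff, not part of the actual codebase.

```python
from dataclasses import dataclass
from typing import List, Sequence
import torch

@dataclass
class TensorSpec:
    # hypothetical stand-in for a forward_schema entry, carrying the dtype
    # (and compression) the remote server expects for this tensor
    dtype: torch.dtype
    compression: str = "NONE"

def cast_inputs_to_schema(inputs: Sequence[torch.Tensor], schema: Sequence[TensorSpec]) -> List[torch.Tensor]:
    """Cast each input to the dtype its schema entry declares (a no-op if it already matches)."""
    return [tensor.to(spec.dtype) for tensor, spec in zip(inputs, schema)]

# usage sketch: a float32 activation destined for a float16 server
schema = [TensorSpec(dtype=torch.float16)]
inputs = [torch.randn(1, 4, 64)]                 # client-side float32 activations
payload = cast_inputs_to_schema(inputs, schema)  # would then go through serialize_torch_tensor
assert payload[0].dtype == torch.float16
```
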
@@ -48,6 +48,9 @@ class TransformerConnectionHandler(ConnectionHandler):
         while request.tensors:  # iterate while user is willing to supply tensors
             hidden_states = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
+            # Cast inputs to backend dtype
+            hidden_states = [tensor.to(requested_backends[0].dtype) for tensor in hidden_states]
             # run request tensors through all requested modules, update caches
             for backend, cache_handle in zip(requested_backends, cache_handles):
                 cache_metadata[:, 0], cache_metadata[:, 1] = cache_handle, prefix_length
@@ -62,7 +65,7 @@ class TransformerConnectionHandler(ConnectionHandler):
             # serialize and send last layer outputs
             yield runtime_pb2.ExpertResponse(
                 tensors=[
-                    serialize_torch_tensor(result, proto.compression, allow_inplace=True)
+                    serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
                     for result, proto in zip(
                         hidden_states, nested_flatten(requested_backends[-1].outputs_schema)
                     )
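
On the server side, the handler now performs the mirror-image casts: deserialized activations are converted to the dtype of the requested backends before the blocks run, and the results are cast back to the dtype declared in the last backend's output schema before being serialized into the response. Below is a hedged sketch of that round trip; `run_inference_step`, `backend_dtype`, and `output_dtype` are illustrative names standing in for the handler loop, `requested_backends[0].dtype`, and the outputs-schema dtype, and the real handler also threads cache handles and prefix lengths that are omitted here.

```python
from typing import List
import torch

def run_inference_step(
    hidden_states: List[torch.Tensor],
    blocks: List[torch.nn.Module],
    backend_dtype: torch.dtype,   # stands in for requested_backends[0].dtype
    output_dtype: torch.dtype,    # stands in for the dtype declared in outputs_schema
) -> List[torch.Tensor]:
    # cast deserialized inputs to the precision the backends compute in
    hidden_states = [tensor.to(backend_dtype) for tensor in hidden_states]

    # run the activations through every requested block in order
    for block in blocks:
        hidden_states = [block(tensor) for tensor in hidden_states]

    # cast results back to the schema dtype before they are serialized to the client
    return [tensor.to(output_dtype) for tensor in hidden_states]

# usage sketch: a bfloat16 backend that still answers the client in float32
blocks = [torch.nn.Linear(64, 64).to(torch.bfloat16)]
outputs = run_inference_step([torch.randn(1, 64)], blocks, torch.bfloat16, torch.float32)
assert outputs[0].dtype == torch.float32
```
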
@@ -242,7 +245,7 @@ class TransformerConnectionHandler(ConnectionHandler):
             head_dim = backend.module.self_attention.head_dim
             cache_descriptor = TensorDescriptor(
-                size=(2, batch_size, MAX_LENGTH, num_heads, head_dim), dtype=torch.float32
+                size=(2, batch_size, MAX_LENGTH, num_heads, head_dim), dtype=backend.dtype
             )
             # [key_or_value, batch_size, max_length, num_heads, head_dim]
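
The last hunk stops hard-coding the attention cache as float32 and allocates it in the backend's own dtype, which roughly halves cache memory for a float16 or bfloat16 backend. A small illustration of the size difference, using a plain tensor as a stand-in for whatever the `TensorDescriptor` ultimately allocates (the shape follows the `[key_or_value, batch_size, max_length, num_heads, head_dim]` layout noted in the comment; the concrete dimensions are made up):

```python
import torch

batch_size, max_length, num_heads, head_dim = 1, 2048, 32, 128

def cache_bytes(dtype: torch.dtype) -> int:
    # [key_or_value, batch_size, max_length, num_heads, head_dim]
    cache = torch.empty(2, batch_size, max_length, num_heads, head_dim, dtype=dtype)
    return cache.element_size() * cache.nelement()

print(cache_bytes(torch.float32))  # previous hard-coded dtype
print(cache_bytes(torch.float16))  # backend.dtype for a half-precision backend: half the memory
```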
