black-isort

pull/65/head
justheuristic 2 years ago
parent 5af1c9e3b4
commit 28971dcedd

@ -71,9 +71,12 @@ class RemoteTransformerBlockInferenceSession:
if not next_input_message.uid and not next_input_message.tensors:
break # this message means "done sending"
def step(self,
new_hidden_states: torch.Tensor,
prompts: Optional[torch.Tensor] = None, hypo_ids: Optional[torch.Tensor] = None):
def step(
self,
new_hidden_states: torch.Tensor,
prompts: Optional[torch.Tensor] = None,
hypo_ids: Optional[torch.Tensor] = None,
):
"""
Inference step: send a chunk of input tensors and receive a chunk of outputs
:prompts: optional DEEP prompts, added to a prefix of each layer's outputs,
@ -193,7 +196,7 @@ class RemoteSequentialInferenceSession:
else:
assert prompts.ndim == 4 and prompts.shape[0] == len(self.sequence_manager)
for session in self.inference_sessions:
outputs = session.step(inputs, prompts[self.chosen_spans[0].start: self.chosen_spans[0].end], **kwargs)
outputs = session.step(inputs, prompts[self.chosen_spans[0].start : self.chosen_spans[0].end], **kwargs)
assert outputs.shape == inputs.shape, f"expected {inputs.shape}, got {outputs.shape}"
inputs = outputs
return inputs

@ -70,7 +70,9 @@ class TransformerBackend(ModuleBackend):
attention_cache_handle = int(cache_metadata[0, 0].item())
prefix_length = int(cache_metadata[0, 1].item())
(hidden_states, hypo_ids) = inputs
assert (hidden_states.ndim == 3), "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
assert (
hidden_states.ndim == 3
), "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
with self.memory_cache.use_cache(attention_cache_handle) as cache:
assert isinstance(self.module, BloomBlock) and cache.shape[0] == 2 and cache.ndim == 5
@ -78,7 +80,9 @@ class TransformerBackend(ModuleBackend):
cache[:, :] = cache[:, hypo_ids] # in-place reorder cache by hypo ids
layer_past = past_k, past_v = cache[0, :, :prefix_length], cache[1, :, :prefix_length]
print("METADATA:", cache_metadata, past_k.shape, past_v.shape)
hidden_states, (new_k, new_v) = self.module.forward(hidden_states, layer_past=layer_past, use_cache=True)
hidden_states, (new_k, new_v) = self.module.forward(
hidden_states, layer_past=layer_past, use_cache=True
)
# todo remove these asserts once we pass all tests
new_length = new_v.shape[1]

@ -95,8 +95,12 @@ class TransformerConnectionHandler(ConnectionHandler):
assert isinstance(
hidden_states, torch.Tensor
), f"hidden states must be tensor, got {type(hidden_states)}"
assert hidden_states.ndim == 3, f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
(hidden_states,) = await backend.inference_pool.submit_task(cache_metadata, hidden_states, hypo_ids)
assert (
hidden_states.ndim == 3
), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
(hidden_states,) = await backend.inference_pool.submit_task(
cache_metadata, hidden_states, hypo_ids
)
# serialize and send last layer outputs
yield runtime_pb2.ExpertResponse(

Loading…
Cancel
Save