From 28971dcedd07c872f8883cf6d2cab479fc0200d9 Mon Sep 17 00:00:00 2001
From: justheuristic
Date: Tue, 6 Sep 2022 19:59:52 +0300
Subject: [PATCH] black-isort

---
 src/client/inference_session.py | 11 +++++++----
 src/server/backend.py           |  8 ++++++--
 src/server/handler.py           |  8 ++++++--
 3 files changed, 19 insertions(+), 8 deletions(-)

diff --git a/src/client/inference_session.py b/src/client/inference_session.py
index 4b43eb1..bb1455f 100644
--- a/src/client/inference_session.py
+++ b/src/client/inference_session.py
@@ -71,9 +71,12 @@ class RemoteTransformerBlockInferenceSession:
             if not next_input_message.uid and not next_input_message.tensors:
                 break  # this message means "done sending"
 
-    def step(self,
-             new_hidden_states: torch.Tensor,
-             prompts: Optional[torch.Tensor] = None, hypo_ids: Optional[torch.Tensor] = None):
+    def step(
+        self,
+        new_hidden_states: torch.Tensor,
+        prompts: Optional[torch.Tensor] = None,
+        hypo_ids: Optional[torch.Tensor] = None,
+    ):
         """
         Inference step: send a chunk of input tensors and receive a chunk of outputs
         :prompts: optional DEEP prompts, added to a prefix of each layer's outputs,
@@ -193,7 +196,7 @@ class RemoteSequentialInferenceSession:
         else:
             assert prompts.ndim == 4 and prompts.shape[0] == len(self.sequence_manager)
         for session in self.inference_sessions:
-            outputs = session.step(inputs, prompts[self.chosen_spans[0].start: self.chosen_spans[0].end], **kwargs)
+            outputs = session.step(inputs, prompts[self.chosen_spans[0].start : self.chosen_spans[0].end], **kwargs)
             assert outputs.shape == inputs.shape, f"expected {inputs.shape}, got {outputs.shape}"
             inputs = outputs
         return inputs
diff --git a/src/server/backend.py b/src/server/backend.py
index 8b67867..27ee1ad 100644
--- a/src/server/backend.py
+++ b/src/server/backend.py
@@ -70,7 +70,9 @@ class TransformerBackend(ModuleBackend):
         attention_cache_handle = int(cache_metadata[0, 0].item())
         prefix_length = int(cache_metadata[0, 1].item())
         (hidden_states, hypo_ids) = inputs
-        assert (hidden_states.ndim == 3), "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
+        assert (
+            hidden_states.ndim == 3
+        ), "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
 
         with self.memory_cache.use_cache(attention_cache_handle) as cache:
             assert isinstance(self.module, BloomBlock) and cache.shape[0] == 2 and cache.ndim == 5
@@ -78,7 +80,9 @@ class TransformerBackend(ModuleBackend):
                 cache[:, :] = cache[:, hypo_ids]  # in-place reorder cache by hypo ids
             layer_past = past_k, past_v = cache[0, :, :prefix_length], cache[1, :, :prefix_length]
             print("METADATA:", cache_metadata, past_k.shape, past_v.shape)
-            hidden_states, (new_k, new_v) = self.module.forward(hidden_states, layer_past=layer_past, use_cache=True)
+            hidden_states, (new_k, new_v) = self.module.forward(
+                hidden_states, layer_past=layer_past, use_cache=True
+            )
 
             # todo remove these asserts once we pass all tests
             new_length = new_v.shape[1]
diff --git a/src/server/handler.py b/src/server/handler.py
index 13d0275..b2e15f7 100644
--- a/src/server/handler.py
+++ b/src/server/handler.py
@@ -95,8 +95,12 @@ class TransformerConnectionHandler(ConnectionHandler):
                 assert isinstance(
                     hidden_states, torch.Tensor
                 ), f"hidden states must be tensor, got {type(hidden_states)}"
-                assert hidden_states.ndim == 3, f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
-                (hidden_states,) = await backend.inference_pool.submit_task(cache_metadata, hidden_states, hypo_ids)
+                assert (
+                    hidden_states.ndim == 3
+                ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
+                (hidden_states,) = await backend.inference_pool.submit_task(
+                    cache_metadata, hidden_states, hypo_ids
+                )
 
                 # serialize and send last layer outputs
                 yield runtime_pb2.ExpertResponse(
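
Note (illustrative, not part of the patch): the step() signature reformatted
above takes 3-d hidden states and optional 4-d DEEP prompts, one prefix per
remote block, per the asserts visible in the diff. A minimal usage sketch,
assuming a session object constructed elsewhere; the tensor sizes are
illustrative assumptions, only the shapes follow the code:

    import torch

    num_blocks, batch, prefix_len, seq_len, hid_size = 4, 1, 2, 8, 64  # assumed sizes
    hidden_states = torch.randn(batch, seq_len, hid_size)           # [batch_size, seq_len, hid_size]
    prompts = torch.zeros(num_blocks, batch, prefix_len, hid_size)  # one prompt prefix per block

    # session = RemoteSequentialInferenceSession(...)  # construction is outside this diff
    # outputs = session.step(hidden_states, prompts)   # outputs.shape == hidden_states.shape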
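
Note (illustrative, not part of the patch): cache[:, :] = cache[:, hypo_ids]
in backend.py reorders the attention cache in place along the hypothesis
dimension after a beam-search step. A self-contained sketch of the same
advanced-indexing trick with toy shapes (the real cache is 5-d, with
cache.shape[0] == 2 holding keys and values):

    import torch

    cache = torch.arange(2 * 3 * 4, dtype=torch.float32).reshape(2, 3, 4)  # [kv, num_hypos, length]
    hypo_ids = torch.tensor([2, 0, 0])  # beam search kept hypotheses 2, 0, 0
    cache[:, :] = cache[:, hypo_ids]    # gather rows by hypo id, write back in place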