Fix tests

generation
Artem Chumachenko 2 years ago
parent d351431e95
commit 1d7c550485

@ -100,7 +100,7 @@ class RemoteTransformerBlockInferenceSession:
)
outputs = list(map(deserialize_torch_tensor, outputs_serialized.tensors))
assert outputs[0].shape == inputs[0].shape, f"expected outputs[0] to be hidden states but got {outputs[0]}"
return outputs
return outputs[0]
async def _step(self, inputs_serialized: runtime_pb2.ExpertRequest) -> runtime_pb2.ExpertResponse:
"""Inference step on serialized data. This code is meant to be run inside RemoteExpertWorker"""

@ -141,7 +141,7 @@ class RemoteSequentialInferenceSession:
def step(self, inputs: torch.Tensor):
assert not self.closed
for session in self.active_sessions:
outputs = session.step(inputs)[0]
outputs = session.step(inputs)
assert outputs.shape == inputs.shape, f"expected {inputs.shape}, got {outputs.shape}"
inputs = outputs
return inputs

Loading…
Cancel
Save