forward_backward
Denis Mazur 3 months ago
parent efee5d1fa8
commit f8b922ca34

@ -590,3 +590,44 @@ class TransformerConnectionHandler(ConnectionHandler):
result.update(block_info)
return runtime_pb2.ExpertInfo(serialized_info=MSGPackSerializer.dumps(result))
async def _wrap_input_stream(stream):
while True:
expert_request = await anext(stream)
yield expert_request
print(expert_request.metadata)
if expert_request.metadata.get("SEP"):
break
async def rpc_forward_backward_stream(
self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
) -> AsyncIterator[runtime_pb2.ExpertRequest]:
async with timeout(self.request_timeout):
wrapped_requests = self._wrap_input_stream(requests)
# Parse requests and prepare backends
uid_str, flat_inputs, metadata = await self._gather_inputs(wrapped_requests, context)
requested_uids = self._check_uids(uid_str)
self._log_request("rpc_forward_stream", requested_uids, context)
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
active_adapter = self._get_active_adapter(metadata)
points = metadata.get("points", 0)
args_structure = metadata.get("args_structure")
assert isinstance(
points, (float, int)
), f"rpc_forward_stream should have number of points as number or None, got {points}"
hidden_states = await run_rpc_forward(
*flat_inputs,
requested_backends=requested_backends,
prioritizer=self._prioritizer,
active_adapter=active_adapter,
points=points,
args_structure=args_structure,
)
for tensor in self._serialize_outputs(hidden_states, requested_backends, metadata):
for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE):
print("EOS")
yield runtime_pb2.ExpertResponse(tensors=[part], metadata=MSGPackSerializer.dumps({"EOS": True}))

Loading…
Cancel
Save