|
|
|
@ -590,3 +590,44 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|
|
|
|
result.update(block_info)
|
|
|
|
|
|
|
|
|
|
return runtime_pb2.ExpertInfo(serialized_info=MSGPackSerializer.dumps(result))
|
|
|
|
|
|
|
|
|
|
@staticmethod
async def _wrap_input_stream(stream):
    """Re-yield requests from ``stream`` up to and including the separator request.

    Every request pulled from ``stream`` is yielded unchanged. After a request whose
    metadata carries a truthy ``"SEP"`` entry has been yielded, iteration stops and the
    remainder of ``stream`` is left unconsumed (the underlying iterator is not closed).
    If ``stream`` ends without a separator, this generator simply ends as well.

    :param stream: async iterator of request objects exposing a ``metadata`` attribute
    :returns: async generator over the leading requests of ``stream``
    """
    # `async for` (rather than `await anext(stream)` in a `while True` loop) lets a
    # normally exhausted stream end this generator; a StopAsyncIteration escaping an
    # async generator body would otherwise be re-raised as RuntimeError.
    async for expert_request in stream:
        yield expert_request

        request_metadata = expert_request.metadata
        if isinstance(request_metadata, bytes):
            # NOTE(review): assumes serialized metadata is a MSGPack-encoded mapping,
            # like the metadata this handler attaches to its responses - confirm.
            request_metadata = MSGPackSerializer.loads(request_metadata) if request_metadata else None

        if request_metadata and request_metadata.get("SEP"):
            break
|
|
|
|
|
|
|
|
|
|
async def rpc_forward_backward_stream(
    self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
) -> AsyncIterator[runtime_pb2.ExpertResponse]:
    """Run a forward pass over the leading part of a request stream and stream back the outputs.

    Requests are consumed through ``_wrap_input_stream`` and gathered with
    ``_gather_inputs``; the forward pass is executed with ``run_rpc_forward`` on the
    backends named by the request uid string. Each serialized output tensor is passed
    through ``split_for_streaming`` with ``DEFAULT_MAX_MSG_SIZE`` and every resulting part
    is sent as its own ``ExpertResponse``.

    NOTE(review): despite the method name, no backward pass is performed in this body.
    NOTE(review): every streamed part (not only the last one) is tagged ``{"EOS": True}``;
    confirm that this is what the client expects.

    :param requests: incoming stream of ``ExpertRequest`` messages
    :param context: P2P context of the call, forwarded to input gathering and logging
    :returns: async iterator of ``ExpertResponse`` messages, one tensor part each
    :raises AssertionError: if the ``points`` metadata entry is not an ``int`` or ``float``
    """
    async with timeout(self.request_timeout):
        # The helper takes only the stream (no `self`), so resolve it on the class
        # instead of calling it as a bound method.
        wrapped_requests = type(self)._wrap_input_stream(requests)

        # Parse requests and prepare backends
        uid_str, flat_inputs, metadata = await self._gather_inputs(wrapped_requests, context)
        requested_uids = self._check_uids(uid_str)
        self._log_request("rpc_forward_backward_stream", requested_uids, context)

        requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
        active_adapter = self._get_active_adapter(metadata)
        points = metadata.get("points", 0)
        args_structure = metadata.get("args_structure")
        assert isinstance(
            points, (float, int)
        ), f"rpc_forward_backward_stream should have number of points as number or None, got {points}"

        hidden_states = await run_rpc_forward(
            *flat_inputs,
            requested_backends=requested_backends,
            prioritizer=self._prioritizer,
            active_adapter=active_adapter,
            points=points,
            args_structure=args_structure,
        )

        # The marker payload is identical for every part, so serialize it once.
        eos_metadata = MSGPackSerializer.dumps({"EOS": True})

        # Split the serialized outputs for streaming and respond to the client
        for tensor in self._serialize_outputs(hidden_states, requested_backends, metadata):
            for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE):
                yield runtime_pb2.ExpertResponse(tensors=[part], metadata=eos_metadata)
|
|
|
|
|