|
|
|
@ -592,7 +592,7 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|
|
|
|
return runtime_pb2.ExpertInfo(serialized_info=MSGPackSerializer.dumps(result))
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
async def _wrap_input_stream(stream):
|
|
|
|
|
async def _read_until_eos(stream):
|
|
|
|
|
while True:
|
|
|
|
|
expert_request = await anext(stream)
|
|
|
|
|
yield expert_request
|
|
|
|
@ -605,10 +605,9 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|
|
|
|
self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
|
|
|
|
|
) -> AsyncIterator[runtime_pb2.ExpertRequest]:
|
|
|
|
|
async with timeout(self.request_timeout):
|
|
|
|
|
wrapped_requests = self._wrap_input_stream(requests)
|
|
|
|
|
|
|
|
|
|
# Parse requests and prepare backends
|
|
|
|
|
uid_str, flat_inputs, metadata = await self._gather_inputs(wrapped_requests, context)
|
|
|
|
|
uid_str, flat_inputs, metadata = await self._gather_inputs(self._read_until_eos(requests), context)
|
|
|
|
|
requested_uids = self._check_uids(uid_str)
|
|
|
|
|
self._log_request("rpc_forward_stream", requested_uids, context)
|
|
|
|
|
|
|
|
|
@ -620,6 +619,8 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|
|
|
|
points, (float, int)
|
|
|
|
|
), f"rpc_forward_stream should have number of points as number or None, got {points}"
|
|
|
|
|
|
|
|
|
|
print(f"{requested_backends=}, {active_adapter=}, {points=}, {args_structure=}")
|
|
|
|
|
|
|
|
|
|
hidden_states = await run_rpc_forward(
|
|
|
|
|
*flat_inputs,
|
|
|
|
|
requested_backends=requested_backends,
|
|
|
|
@ -632,4 +633,36 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|
|
|
|
for tensor in self._serialize_outputs(hidden_states, requested_backends, metadata):
|
|
|
|
|
for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE):
|
|
|
|
|
print("EOS")
|
|
|
|
|
yield runtime_pb2.ExpertResponse(tensors=[part], metadata=MSGPackSerializer.dumps({"EOS": True}))
|
|
|
|
|
yield runtime_pb2.ExpertResponse(tensors=[part], metadata=MSGPackSerializer.dumps({"_EOS": True}))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
####
|
|
|
|
|
new_uid_str, flat_extra_inputs, extra_metadata = await self._gather_inputs(self._read_until_eos(requests), context)
|
|
|
|
|
backward_args_structure = extra_metadata.get("args_structure")
|
|
|
|
|
assert len(flat_extra_inputs) == 1
|
|
|
|
|
assert new_uid_str == uid_str
|
|
|
|
|
print("I solemnly swear to think about how to use extra_metadata for pushing when it comes to this")
|
|
|
|
|
grad_outputs, = flat_extra_inputs
|
|
|
|
|
|
|
|
|
|
print("HERE!")
|
|
|
|
|
|
|
|
|
|
print("FLAT INPUTS", flat_inputs)
|
|
|
|
|
print("GRAD OUTPUTS", grad_outputs)
|
|
|
|
|
print(backward_args_structure)
|
|
|
|
|
|
|
|
|
|
grads = await run_rpc_backward(
|
|
|
|
|
flat_inputs[0],
|
|
|
|
|
grad_outputs,
|
|
|
|
|
*flat_inputs[1:],
|
|
|
|
|
requested_backends=requested_backends,
|
|
|
|
|
prioritizer=self._prioritizer,
|
|
|
|
|
active_adapter=active_adapter,
|
|
|
|
|
points=points,
|
|
|
|
|
args_structure=backward_args_structure,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Split the serialized_grad_inputs for streaming and respond
|
|
|
|
|
for tensor in self._serialize_grads(grads, requested_backends, metadata):
|
|
|
|
|
for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE):
|
|
|
|
|
print("SENDING GRADS:", part)
|
|
|
|
|
yield runtime_pb2.ExpertResponse(tensors=[part])
|
|
|
|
|