forward_backward
Denis Mazur 1 month ago
parent 16619f4564
commit 26d4cd855d

File diff suppressed because one or more lines are too long

@@ -592,7 +592,7 @@ class TransformerConnectionHandler(ConnectionHandler):
return runtime_pb2.ExpertInfo(serialized_info=MSGPackSerializer.dumps(result))
@staticmethod
async def _wrap_input_stream(stream):
async def _read_until_eos(stream):
while True:
expert_request = await anext(stream)
yield expert_request
@@ -605,10 +605,9 @@ class TransformerConnectionHandler(ConnectionHandler):
self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
) -> AsyncIterator[runtime_pb2.ExpertRequest]:
async with timeout(self.request_timeout):
wrapped_requests = self._wrap_input_stream(requests)
# Parse requests and prepare backends
uid_str, flat_inputs, metadata = await self._gather_inputs(wrapped_requests, context)
uid_str, flat_inputs, metadata = await self._gather_inputs(self._read_until_eos(requests), context)
requested_uids = self._check_uids(uid_str)
self._log_request("rpc_forward_stream", requested_uids, context)
@@ -620,6 +619,8 @@ class TransformerConnectionHandler(ConnectionHandler):
points, (float, int)
), f"rpc_forward_stream should have number of points as number or None, got {points}"
print(f"{requested_backends=}, {active_adapter=}, {points=}, {args_structure=}")
hidden_states = await run_rpc_forward(
*flat_inputs,
requested_backends=requested_backends,
@@ -632,4 +633,36 @@ class TransformerConnectionHandler(ConnectionHandler):
for tensor in self._serialize_outputs(hidden_states, requested_backends, metadata):
for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE):
print("EOS")
yield runtime_pb2.ExpertResponse(tensors=[part], metadata=MSGPackSerializer.dumps({"EOS": True}))
yield runtime_pb2.ExpertResponse(tensors=[part], metadata=MSGPackSerializer.dumps({"_EOS": True}))
####
new_uid_str, flat_extra_inputs, extra_metadata = await self._gather_inputs(self._read_until_eos(requests), context)
backward_args_structure = extra_metadata.get("args_structure")
assert len(flat_extra_inputs) == 1
assert new_uid_str == uid_str
print("I solemnly swear to think about how to use extra_metadata for pushing when it comes to this")
grad_outputs, = flat_extra_inputs
print("HERE!")
print("FLAT INPUTS", flat_inputs)
print("GRAD OUTPUTS", grad_outputs)
print(backward_args_structure)
grads = await run_rpc_backward(
flat_inputs[0],
grad_outputs,
*flat_inputs[1:],
requested_backends=requested_backends,
prioritizer=self._prioritizer,
active_adapter=active_adapter,
points=points,
args_structure=backward_args_structure,
)
# Split the serialized_grad_inputs for streaming and respond
for tensor in self._serialize_grads(grads, requested_backends, metadata):
for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE):
print("SENDING GRADS:", part)
yield runtime_pb2.ExpertResponse(tensors=[part])

Loading…
Cancel
Save