|
|
|
@ -14,6 +14,7 @@ from src.server.backend import MAX_LENGTH, TransformerBackend
|
|
|
|
|
|
|
|
|
|
class TransformerConnectionHandler(ConnectionHandler):
|
|
|
|
|
"""Handles three request types: forward, backward and forward-incremental (inference)"""
|
|
|
|
|
|
|
|
|
|
module_backends: Dict[ModuleUID, TransformerBackend]
|
|
|
|
|
|
|
|
|
|
def __init__(self, dht: DHT, module_backends: Dict[str, TransformerBackend]):
|
|
|
|
@ -42,18 +43,23 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|
|
|
|
# run request tensors through all requested modules, update caches
|
|
|
|
|
for backend, cache_handle in zip(requested_backends, cache_handles):
|
|
|
|
|
cache_metadata[0, 0], cache_metadata[0, 1] = cache_handle, prefix_length
|
|
|
|
|
assert len(hidden_states) == 1 and hidden_states[0].ndim == 3, \
|
|
|
|
|
f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
|
|
|
|
|
assert (
|
|
|
|
|
len(hidden_states) == 1 and hidden_states[0].ndim == 3
|
|
|
|
|
), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
|
|
|
|
|
|
|
|
|
|
hidden_states = await backend.inference_pool.submit_task(cache_metadata, *hidden_states)
|
|
|
|
|
assert isinstance(hidden_states, (list, tuple))
|
|
|
|
|
assert len(hidden_states) == 1 and hidden_states[0].ndim == 3
|
|
|
|
|
|
|
|
|
|
# serialize and send last layer outputs
|
|
|
|
|
yield runtime_pb2.ExpertResponse(tensors=[
|
|
|
|
|
serialize_torch_tensor(result, proto.compression, allow_inplace=True)
|
|
|
|
|
for result, proto in zip(hidden_states, nested_flatten(requested_backends[-1].outputs_schema))
|
|
|
|
|
])
|
|
|
|
|
yield runtime_pb2.ExpertResponse(
|
|
|
|
|
tensors=[
|
|
|
|
|
serialize_torch_tensor(result, proto.compression, allow_inplace=True)
|
|
|
|
|
for result, proto in zip(
|
|
|
|
|
hidden_states, nested_flatten(requested_backends[-1].outputs_schema)
|
|
|
|
|
)
|
|
|
|
|
]
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# prepare for next step
|
|
|
|
|
prefix_length += hidden_states[0].shape[1]
|
|
|
|
@ -63,7 +69,7 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|
|
|
|
|
|
|
|
|
def _check_header(self, request: runtime_pb2.ExpertRequest) -> Sequence[ModuleUID]:
|
|
|
|
|
"""Check that the first request to rpc_inference is valid"""
|
|
|
|
|
uids = (request.uid or '').split(CHAIN_DELIMITER)
|
|
|
|
|
uids = (request.uid or "").split(CHAIN_DELIMITER)
|
|
|
|
|
if not uids:
|
|
|
|
|
raise RuntimeError("User did not provide any uids")
|
|
|
|
|
for uid in uids:
|
|
|
|
@ -86,11 +92,3 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|
|
|
|
handles.append(await stack.enter_async_context(backend.memory_cache.allocate_cache(cache_descriptor)))
|
|
|
|
|
|
|
|
|
|
yield handles
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|