@@ -75,10 +75,11 @@ class TransformerConnectionHandler(ConnectionHandler):
         context: P2PContext,
     ) -> AsyncIterator[runtime_pb2.ExpertRequest]:
         """Compute a single step of inference using attention cache; update attention cache accordingly."""
+
+        request = await anext(requests)
+        requested_uids = self._check_uids(request.uid)
+        self._log_request("rpc_inference.open", requested_uids, context)
         try:
-            logger.debug("Opened rpc_inference()")
-            request = await anext(requests)
-            requested_uids = self._check_uids(request.uid)
             metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
             requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
             max_length = metadata.get("max_length")
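Aside (not part of the diff): the handler above reads its per-request options back out of `request.metadata` via `MSGPackSerializer.loads`. The sketch below shows the corresponding packing under assumed values; the field names follow the keys the handler reads (`max_length`, `points`), but the call site and the values are illustrative only.

from hivemind.utils.serializer import MSGPackSerializer

# Hypothetical client-side packing of the options that rpc_inference/rpc_forward read back
metadata_bytes = MSGPackSerializer.dumps(dict(max_length=2048, points=0))

# Mirrors the handler's decoding path shown in the hunk above
metadata = MSGPackSerializer.loads(metadata_bytes) if metadata_bytes else {}
assert metadata.get("max_length") == 2048
assert metadata.get("points", 0) == 0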
@@ -167,12 +168,14 @@ class TransformerConnectionHandler(ConnectionHandler):
                 prefix_length += hidden_states.shape[1]
                 request = await (anext(requests))
         finally:
-            logger.debug("Closed rpc_inference()")
+            self._log_request("rpc_inference.close", requested_uids, context)
 
     async def rpc_forward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
         # Parse request and prepare backends
         flat_inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
         requested_uids = self._check_uids(request.uid)
+        self._log_request("rpc_forward", requested_uids, context)
+
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
         metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
         points = metadata.get("points", 0)
@@ -199,6 +202,8 @@ class TransformerConnectionHandler(ConnectionHandler):
         # Parse requests and prepare backends
         uid_str, flat_inputs, metadata = await self._gather_inputs(requests, context)
         requested_uids = self._check_uids(uid_str)
+        self._log_request("rpc_forward_stream", requested_uids, context)
+
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
         points = metadata.get("points", 0)
         assert isinstance(
@@ -227,6 +232,8 @@ class TransformerConnectionHandler(ConnectionHandler):
         # Parse requests and prepare backends
         flat_tensors = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
         requested_uids = self._check_uids(request.uid)
+        self._log_request("rpc_backward", requested_uids, context)
+
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
         metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
         points = metadata.get("points", 0)
@@ -257,9 +264,10 @@ class TransformerConnectionHandler(ConnectionHandler):
     async def rpc_backward_stream(
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
     ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
-
         uids_header, flat_tensors, metadata = await self._gather_inputs(requests, context)
         requested_uids = self._check_uids(uids_header)
+        self._log_request("rpc_backward_stream", requested_uids, context)
+
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
         points = metadata.get("points", 0)
         assert isinstance(
@@ -307,19 +315,39 @@ class TransformerConnectionHandler(ConnectionHandler):
         """Allocate memory caches for each transformer block, return cache handles"""
         async with contextlib.AsyncExitStack() as stack:
             handles = []
+            total_size = 0
+            backend = None
             for backend in backends:
                 num_heads = backend.module.self_attention.num_heads
                 head_dim = backend.module.self_attention.head_dim
 
-                cache_descriptor = TensorDescriptor(
-                    size=(2, batch_size, max_length, num_heads, head_dim), dtype=backend.dtype
-                )
+                descr = TensorDescriptor(size=(2, batch_size, max_length, num_heads, head_dim), dtype=backend.dtype)
                 # [key_or_value, batch_size, max_length, num_heads, head_dim]
 
-                handles.append(await stack.enter_async_context(backend.memory_cache.allocate_cache(cache_descriptor)))
+                handles.append(await stack.enter_async_context(backend.memory_cache.allocate_cache(descr)))
+                total_size += descr.numel() * torch.finfo(descr.dtype).bits // 8
 
+            gib = 1024**3
+            if backend is not None:
+                cur_size = backend.memory_cache.current_size_bytes
+                max_size = backend.memory_cache.max_size_bytes
+                friendly_max_size = f"{max_size / gib:.2f}" if max_size != 2**64 - 1 else "inf"
+                cache_stats = f"used {cur_size / gib:.2f}/{friendly_max_size} GiB ({cur_size / max_size * 100:.1f}%)"
+            else:
+                cache_stats = f"cache stats n/a"
+            logger.info(f"rpc_inference.alloc(total_size={total_size / gib:.2f} GiB), {cache_stats}")
+
             yield handles
 
+    def _log_request(self, method: str, uids: List[ModuleUID], context: P2PContext) -> None:
+        friendly_uids = [uid.split(".")[-1] for uid in uids if "." in uid]
+        friendly_uids = [int(uid) for uid in friendly_uids if uid.isdigit()]
+        friendly_uids = f"{min(friendly_uids)}:{max(friendly_uids) + 1}" if friendly_uids else uids
+
+        friendly_remote_id = "..." + str(context.remote_id)[-6:]
+
+        logger.info(f"{method}(blocks={friendly_uids}, remote_peer={friendly_remote_id})")
+
 
 async def _rpc_forward(
     *flat_tensors: torch.Tensor,
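Aside (not part of the diff): a minimal, self-contained sketch of what the two new pieces of logging compute, using assumed values. The UID prefix `bigscience/bloom-petals` and the shape numbers (batch_size=1, max_length=2048, 32 heads of dim 128, float16 cache) are illustrative assumptions, not values taken from the code above.

import torch

# Friendly block range, mirroring _log_request: "<model>.3", "<model>.4", "<model>.5" -> "3:6"
uids = [f"bigscience/bloom-petals.{i}" for i in range(3, 6)]
friendly_uids = [uid.split(".")[-1] for uid in uids if "." in uid]
friendly_uids = [int(uid) for uid in friendly_uids if uid.isdigit()]
friendly_uids = f"{min(friendly_uids)}:{max(friendly_uids) + 1}" if friendly_uids else uids
print(friendly_uids)  # -> 3:6

# Per-block attention cache size, mirroring the cache-allocation hunk above:
# a (2, batch_size, max_length, num_heads, head_dim) tensor, sized as numel * bits // 8
batch_size, max_length, num_heads, head_dim = 1, 2048, 32, 128
numel = 2 * batch_size * max_length * num_heads * head_dim
size_bytes = numel * torch.finfo(torch.float16).bits // 8
print(f"{size_bytes / 1024**2:.0f} MiB per block")  # -> 32 MiB per block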