Improve server's logging (#96)

Log all RPC calls with block indices and shortened peer IDs, print attention cache stats.
server-timeouts
Alexander Borzunov 2 years ago committed by GitHub
parent fdb3583a8c
commit d8ef09146e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -4,5 +4,5 @@ accelerate==0.10.0
huggingface-hub==0.7.0
transformers==4.21.3
protobuf>=3.20.3,<4.0dev
git+https://github.com/learning-at-home/hivemind@1e4af434f35ad43208e7e5df569c5ff5eb79681b
git+https://github.com/learning-at-home/hivemind@be88b4280cdd87432168e1da238e532f1364078b
humanfriendly

@ -75,10 +75,11 @@ class TransformerConnectionHandler(ConnectionHandler):
context: P2PContext,
) -> AsyncIterator[runtime_pb2.ExpertRequest]:
"""Compute a single step of inference using attention cache; update attention cache accordingly."""
request = await anext(requests)
requested_uids = self._check_uids(request.uid)
self._log_request("rpc_inference.open", requested_uids, context)
try:
logger.debug("Opened rpc_inference()")
request = await anext(requests)
requested_uids = self._check_uids(request.uid)
metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
max_length = metadata.get("max_length")
@ -167,12 +168,14 @@ class TransformerConnectionHandler(ConnectionHandler):
prefix_length += hidden_states.shape[1]
request = await (anext(requests))
finally:
logger.debug("Closed rpc_inference()")
self._log_request("rpc_inference.close", requested_uids, context)
async def rpc_forward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
# Parse request and prepare backends
flat_inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
requested_uids = self._check_uids(request.uid)
self._log_request("rpc_forward", requested_uids, context)
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
points = metadata.get("points", 0)
@ -199,6 +202,8 @@ class TransformerConnectionHandler(ConnectionHandler):
# Parse requests and prepare backends
uid_str, flat_inputs, metadata = await self._gather_inputs(requests, context)
requested_uids = self._check_uids(uid_str)
self._log_request("rpc_forward_stream", requested_uids, context)
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
points = metadata.get("points", 0)
assert isinstance(
@ -227,6 +232,8 @@ class TransformerConnectionHandler(ConnectionHandler):
# Parse requests and prepare backends
flat_tensors = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
requested_uids = self._check_uids(request.uid)
self._log_request("rpc_backward", requested_uids, context)
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
points = metadata.get("points", 0)
@ -257,9 +264,10 @@ class TransformerConnectionHandler(ConnectionHandler):
async def rpc_backward_stream(
self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
) -> AsyncIterator[runtime_pb2.ExpertResponse]:
uids_header, flat_tensors, metadata = await self._gather_inputs(requests, context)
requested_uids = self._check_uids(uids_header)
self._log_request("rpc_backward_stream", requested_uids, context)
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
points = metadata.get("points", 0)
assert isinstance(
@ -307,19 +315,39 @@ class TransformerConnectionHandler(ConnectionHandler):
"""Allocate memory caches for each transformer block, return cache handles"""
async with contextlib.AsyncExitStack() as stack:
handles = []
total_size = 0
backend = None
for backend in backends:
num_heads = backend.module.self_attention.num_heads
head_dim = backend.module.self_attention.head_dim
cache_descriptor = TensorDescriptor(
size=(2, batch_size, max_length, num_heads, head_dim), dtype=backend.dtype
)
descr = TensorDescriptor(size=(2, batch_size, max_length, num_heads, head_dim), dtype=backend.dtype)
# [key_or_value, batch_size, max_length, num_heads, head_dim]
handles.append(await stack.enter_async_context(backend.memory_cache.allocate_cache(cache_descriptor)))
handles.append(await stack.enter_async_context(backend.memory_cache.allocate_cache(descr)))
total_size += descr.numel() * torch.finfo(descr.dtype).bits // 8
gib = 1024**3
if backend is not None:
cur_size = backend.memory_cache.current_size_bytes
max_size = backend.memory_cache.max_size_bytes
friendly_max_size = f"{max_size / gib:.2f}" if max_size != 2**64 - 1 else "inf"
cache_stats = f"used {cur_size / gib:.2f}/{friendly_max_size} GiB ({cur_size / max_size * 100:.1f}%)"
else:
cache_stats = f"cache stats n/a"
logger.info(f"rpc_inference.alloc(total_size={total_size / gib:.2f} GiB), {cache_stats}")
yield handles
def _log_request(self, method: str, uids: List[ModuleUID], context: P2PContext) -> None:
friendly_uids = [uid.split(".")[-1] for uid in uids if "." in uid]
friendly_uids = [int(uid) for uid in friendly_uids if uid.isdigit()]
friendly_uids = f"{min(friendly_uids)}:{max(friendly_uids) + 1}" if friendly_uids else uids
friendly_remote_id = "..." + str(context.remote_id)[-6:]
logger.info(f"{method}(blocks={friendly_uids}, remote_peer={friendly_remote_id})")
async def _rpc_forward(
*flat_tensors: torch.Tensor,

@ -379,11 +379,6 @@ class ModuleContainer(threading.Thread):
Runs ModuleContainer in the current thread. Initializes dht if necessary, starts connection handlers,
runs Runtime (self.runtime) to process incoming requests.
"""
logger.info(f"Serving {len(self.module_backends)} blocks:")
for expert_name, backend in self.module_backends.items():
num_parameters = sum(p.numel() for p in backend.module.parameters() if p.requires_grad)
logger.info(f"{expert_name}: {backend.module.__class__.__name__}, {num_parameters} parameters")
if not self.dht.is_alive():
self.dht.run_in_background(await_ready=True)

Loading…
Cancel
Save