@@ -119,12 +119,11 @@ class TransformerConnectionHandler(ConnectionHandler):
             batch_size = request.tensors[0].size[0] if request.tensors else 1
 
             cache_metadata = torch.tensor(
-                [[-1, -1] for _ in range(batch_size)], dtype=torch.int64
-            )  # [cache_handle, prefix_length]
+                [[-1, -1, -1] for _ in range(batch_size)], dtype=torch.int64
+            )  # [cache_handle, rel_index, prefix_length]
             prefix_length = 0
 
-            async with self._allocate_caches(requested_backends, batch_size, max_length) as cache_handles:
-                assert len(cache_handles) == len(requested_backends)
+            async with self._allocate_cache(requested_backends, batch_size, max_length) as cache_handle:
                 while request.tensors:  # iterate while user is willing to supply tensors
                     hidden_states, prompts, hypo_ids = [
                         deserialize_torch_tensor(tensor) for tensor in request.tensors
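With a single shared cache handle, each `cache_metadata` row grows from `[cache_handle, prefix_length]` to `[cache_handle, rel_index, prefix_length]`, where `rel_index` is the block's position in the requested chain. A minimal sketch of the new layout (the concrete values are hypothetical):

```python
import torch

batch_size = 2
# One row per batch entry; -1 marks "not yet assigned", as in the hunk above.
cache_metadata = torch.tensor([[-1, -1, -1] for _ in range(batch_size)], dtype=torch.int64)

# Later, while running block #3 of the chain with 5 tokens already cached:
cache_handle, rel_index, prefix_length = 42, 3, 5  # hypothetical values
cache_metadata[:] = torch.tensor([cache_handle, rel_index, prefix_length], dtype=torch.int64)
print(cache_metadata)  # tensor([[42, 3, 5], [42, 3, 5]])
```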
@@ -151,14 +150,16 @@ class TransformerConnectionHandler(ConnectionHandler):
                     )
 
                     # run request tensors through all requested modules, update caches
-                    for backend, prompt, cache_handle in zip(requested_backends, prompts, cache_handles):
+                    for rel_index, (backend, prompt) in enumerate(zip(requested_backends, prompts)):
                         if not is_dummy(prompt):
                             hidden_states[:, : prompt.shape[1]] += prompt
                         if hidden_states.numel() == 0:
                             continue  # user passed a tensor with 0 tokens. This is a special case that occurs, e.g.
                             # when user wants to pre-allocate cache or check that server *can* allocate that cache
 
-                        cache_metadata[:, 0], cache_metadata[:, 1] = cache_handle, prefix_length
+                        cache_metadata[:] = torch.tensor(
+                            [cache_handle, rel_index, prefix_length], dtype=torch.int64
+                        )
                         assert isinstance(
                             hidden_states, torch.Tensor
                         ), f"hidden states must be tensor, got {type(hidden_states)}"
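Because every block now shares one cache tensor, the loop enumerates the chain so each backend learns its `rel_index`; the backend side can then slice out its own key/value plane. A rough sketch, assuming the `(n_blocks, 2, batch_size, ...)` layout from `_allocate_cache` below (all sizes made up):

```python
import torch

# Toy dimensions mirroring the TensorDescriptor in _allocate_cache below.
n_blocks, batch_size, n_heads, head_dim, max_length = 4, 2, 8, 64, 16
shared_cache = torch.zeros(n_blocks, 2, batch_size, n_heads * head_dim * max_length)

rel_index = 3  # this backend's position in the requested chain
key_cache, value_cache = shared_cache[rel_index]  # per-block slice, no extra allocation
```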
@@ -177,7 +178,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                             type="inference",
                         )
                         (hidden_states,) = await backend.inference_pool.submit_task(
-                            cache_metadata, hidden_states, hypo_ids, priority=priority
+                            hidden_states, hypo_ids, cache_metadata, priority=priority
                         )
 
                     # serialize and send last layer outputs
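The reorder matters because the pool forwards these arguments positionally to whatever runs the task on the runtime side, so caller and callee must agree on the order. A toy illustration (the step function here is a stand-in, not the real pool API):

```python
import torch

def inference_step(hidden_states: torch.Tensor, hypo_ids: torch.Tensor, cache_metadata: torch.Tensor):
    # Positional unpacking: swapping the caller's argument order silently
    # hands the wrong tensor to each parameter.
    return hidden_states + 1

args = (torch.zeros(1, 3), torch.zeros(1, dtype=torch.int64), torch.zeros(1, 3, dtype=torch.int64))
inference_step(*args)  # matches the new (hidden_states, hypo_ids, cache_metadata) order
```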
@@ -343,33 +344,30 @@ class TransformerConnectionHandler(ConnectionHandler):
         return tuple(uids)
 
     @contextlib.asynccontextmanager
-    async def _allocate_caches(
+    async def _allocate_cache(
         self, backends: Sequence[TransformerBackend], batch_size: int, max_length: int
     ) -> Sequence[int]:
-        """Allocate memory caches for each transformer block, return cache handles"""
-        async with contextlib.AsyncExitStack() as stack:
-            handles = []
-            total_size = 0
-            backend = None
-            for backend in backends:
-                num_heads = backend.module.self_attention.num_heads
-                head_dim = backend.module.self_attention.head_dim
-                descr = TensorDescriptor(size=(2, batch_size, num_heads * head_dim * max_length), dtype=backend.dtype)
-                # ^-- flattened batch-first tensor of both keys and values; based on BLOOM layer_past layout
-                handles.append(await stack.enter_async_context(backend.memory_cache.allocate_cache(descr)))
-                total_size += descr.numel() * torch.finfo(descr.dtype).bits // 8
-
-            gib = 1024**3
-            if backend is not None:
-                cur_size = backend.memory_cache.current_size_bytes
-                max_size = backend.memory_cache.max_size_bytes
-                friendly_max_size = f"{max_size / gib:.2f}" if max_size != 2**64 - 1 else "inf"
-                cache_stats = f"used {cur_size / gib:.2f}/{friendly_max_size} GiB ({cur_size / max_size * 100:.1f}%)"
-            else:
-                cache_stats = f"cache stats n/a"
-            logger.info(f"rpc_inference.alloc(total_size={total_size / gib:.2f} GiB), {cache_stats}")
-
-            yield handles
+        """Allocate memory cache for all transformer blocks, return cache handle"""
+        n_blocks = len(backends)
+
+        backend = backends[0]
+        n_heads = backend.module.self_attention.num_heads
+        head_dim = backend.module.self_attention.head_dim
+        descr = TensorDescriptor(size=(n_blocks, 2, batch_size, n_heads * head_dim * max_length), dtype=backend.dtype)
+        alloc_size = descr.numel() * torch.finfo(descr.dtype).bits // 8
+
+        gib = 1024**3
+        cur_size = backend.memory_cache.current_size_bytes
+        max_size = backend.memory_cache.max_size_bytes
+        friendly_max_size = f"{max_size / gib:.2f}" if max_size != 2**64 - 1 else "inf"
+        logger.info(
+            f"rpc_inference.wait_for_alloc(size={alloc_size / gib:.2f} GiB), "
+            f"already used {cur_size / gib:.2f}/{friendly_max_size} GiB ({cur_size / max_size * 100:.1f}%)"
+        )
+
+        async with backend.memory_cache.allocate_cache(descr) as handle:
+            logger.info(f"rpc_inference.alloc(size={alloc_size / gib:.2f} GiB)")
+            yield handle
 
     def _log_request(self, method: str, uids: Sequence[ModuleUID], context: P2PContext) -> None:
         friendly_uids = [uid.split(".")[-1] for uid in uids if "." in uid]
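For intuition on `alloc_size`: the single descriptor covers keys and values for all blocks at once, so its element count is `n_blocks * 2 * batch_size * n_heads * head_dim * max_length`, and the byte size follows from the dtype width exactly as computed above. A back-of-the-envelope check with made-up numbers:

```python
import torch

# Hypothetical configuration, for illustration only.
n_blocks, batch_size, n_heads, head_dim, max_length = 8, 1, 32, 128, 2048
dtype = torch.bfloat16

numel = n_blocks * 2 * batch_size * n_heads * head_dim * max_length
alloc_size = numel * torch.finfo(dtype).bits // 8  # same formula as _allocate_cache
print(f"{alloc_size / 1024**3:.2f} GiB")  # -> 0.25 GiB for this configuration
```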