Alloc inference cache as one contiguous buffer (#160)

pull/161/head
Alexander Borzunov authored 1 year ago; committed by GitHub
parent 523a7cad33
commit 7cdc57a04b

@@ -48,25 +48,25 @@ class TransformerBackend(ModuleBackend):
             self.kwargs_schema,
         )
 
-    def inference_step(self, cache_metadata: torch.IntTensor, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
+    def inference_step(
+        self, hidden_states: torch.Tensor, hypo_ids: torch.LongTensor, cache_metadata: torch.LongTensor
+    ) -> Tuple[torch.Tensor, ...]:
         num_heads, head_dim = self.module.self_attention.num_heads, self.module.self_attention.head_dim
         with torch.inference_mode():
-            attention_cache_handle = int(cache_metadata[0, 0].item())
-            prefix_length = int(cache_metadata[0, 1].item())
-            (hidden_states, hypo_ids) = inputs
             assert (
                 hidden_states.ndim == 3
             ), "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
+            cache_handle, rel_index, prefix_length = map(int, cache_metadata[0])
 
-            with self.memory_cache.use_cache(attention_cache_handle) as cache:
-                batch_size = cache.shape[1]
-                max_length = cache.numel() // (2 * batch_size * head_dim * num_heads)
-                assert isinstance(self.module, WrappedBloomBlock) and cache.shape[0] == 2 and cache.ndim == 3
+            with self.memory_cache.use_cache(cache_handle) as cache:
+                batch_size = cache.shape[2]
+                max_length = cache.shape[-1] // (head_dim * num_heads)
+                assert isinstance(self.module, WrappedBloomBlock) and cache.shape[1] == 2 and cache.ndim == 4
                 if not is_dummy(hypo_ids):
-                    assert hypo_ids.shape[0] == cache.shape[1]
-                    cache[:, :] = cache[:, hypo_ids]  # in-place reorder cache by hypo ids
-                key_cache = cache[0].view(batch_size, num_heads, head_dim, max_length)
-                value_cache = cache[1].view(batch_size, num_heads, max_length, head_dim)
+                    assert hypo_ids.shape[0] == batch_size
+                    cache[rel_index, :, :] = cache[rel_index, :, hypo_ids]  # in-place reorder cache by hypo ids
+                key_cache = cache[rel_index, 0].view(batch_size, num_heads, head_dim, max_length)
+                value_cache = cache[rel_index, 1].view(batch_size, num_heads, max_length, head_dim)
                 key_past = key_cache.flatten(0, 1)[:, :, :prefix_length]  # [batch * num_heads, head_dim, kv_length]
                 value_past = value_cache.flatten(0, 1)[:, :prefix_length, :]  # [batch * num_heads, kv_length, head_dim]
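With this change, a server's whole span of blocks shares one cache tensor of shape (n_blocks, 2, batch_size, num_heads * head_dim * max_length), and rel_index picks out a block's own key/value slot. A minimal sketch of that indexing; all sizes below are toy values for illustration, not taken from the PR:

```python
import torch

# Toy sizes for illustration only.
n_blocks, batch_size, num_heads, head_dim, max_length = 3, 2, 4, 8, 16

# One contiguous buffer for the whole span of blocks, laid out as in _allocate_cache:
# (n_blocks, 2, batch_size, num_heads * head_dim * max_length).
cache = torch.zeros(n_blocks, 2, batch_size, num_heads * head_dim * max_length)

rel_index = 1  # this block's position within the server's span
key_cache = cache[rel_index, 0].view(batch_size, num_heads, head_dim, max_length)
value_cache = cache[rel_index, 1].view(batch_size, num_heads, max_length, head_dim)

# The views alias the shared buffer, so attention writes land directly in it.
assert key_cache.data_ptr() == cache[rel_index, 0].data_ptr()
assert value_cache.shape == (batch_size, num_heads, max_length, head_dim)
```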

@@ -119,12 +119,11 @@ class TransformerConnectionHandler(ConnectionHandler):
             batch_size = request.tensors[0].size[0] if request.tensors else 1
 
             cache_metadata = torch.tensor(
-                [[-1, -1] for _ in range(batch_size)], dtype=torch.int64
-            )  # [cache_handle, prefix_length]
+                [[-1, -1, -1] for _ in range(batch_size)], dtype=torch.int64
+            )  # [cache_handle, rel_index, prefix_length]
             prefix_length = 0
 
-            async with self._allocate_caches(requested_backends, batch_size, max_length) as cache_handles:
-                assert len(cache_handles) == len(requested_backends)
+            async with self._allocate_cache(requested_backends, batch_size, max_length) as cache_handle:
                 while request.tensors:  # iterate while user is willing to supply tensors
                     hidden_states, prompts, hypo_ids = [
                         deserialize_torch_tensor(tensor) for tensor in request.tensors
@ -151,14 +150,16 @@ class TransformerConnectionHandler(ConnectionHandler):
)
# run request tensors through all requested modules, update caches
for backend, prompt, cache_handle in zip(requested_backends, prompts, cache_handles):
for rel_index, (backend, prompt) in enumerate(zip(requested_backends, prompts)):
if not is_dummy(prompt):
hidden_states[:, : prompt.shape[1]] += prompt
if hidden_states.numel() == 0:
continue # user passed a tensor with 0 tokens. This is a special case that occurs, e.g.
# when user wants to pre-allocate cache or check that server *can* allocate that cache
cache_metadata[:, 0], cache_metadata[:, 1] = cache_handle, prefix_length
cache_metadata[:] = torch.tensor(
[cache_handle, rel_index, prefix_length], dtype=torch.int64
)
assert isinstance(
hidden_states, torch.Tensor
), f"hidden states must be tensor, got {type(hidden_states)}"
@@ -177,7 +178,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                             type="inference",
                         )
                         (hidden_states,) = await backend.inference_pool.submit_task(
-                            cache_metadata, hidden_states, hypo_ids, priority=priority
+                            hidden_states, hypo_ids, cache_metadata, priority=priority
                         )
 
                     # serialize and send last layer outputs
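On the handler side, a single handle now covers all requested backends, and cache_metadata carries [cache_handle, rel_index, prefix_length] so each backend knows which slot of the shared buffer to use. A rough sketch of that convention; the handle value, lengths, and the string stand-ins for backend objects are hypothetical:

```python
import torch

# Hypothetical values standing in for the allocated handle and request state.
cache_handle, prefix_length, batch_size = 42, 0, 2
requested_backends = ["block_3", "block_4", "block_5"]  # stand-ins for TransformerBackend objects

cache_metadata = torch.tensor(
    [[-1, -1, -1] for _ in range(batch_size)], dtype=torch.int64
)  # [cache_handle, rel_index, prefix_length]

for rel_index, backend in enumerate(requested_backends):
    # Every row carries the same triple; inference_step only reads row 0.
    cache_metadata[:] = torch.tensor([cache_handle, rel_index, prefix_length], dtype=torch.int64)
    handle, idx, length = map(int, cache_metadata[0])
    assert (handle, idx, length) == (cache_handle, rel_index, prefix_length)
```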
@@ -343,33 +344,30 @@ class TransformerConnectionHandler(ConnectionHandler):
         return tuple(uids)
 
     @contextlib.asynccontextmanager
-    async def _allocate_caches(
+    async def _allocate_cache(
         self, backends: Sequence[TransformerBackend], batch_size: int, max_length: int
     ) -> Sequence[int]:
-        """Allocate memory caches for each transformer block, return cache handles"""
-        async with contextlib.AsyncExitStack() as stack:
-            handles = []
-            total_size = 0
-            backend = None
-            for backend in backends:
-                num_heads = backend.module.self_attention.num_heads
-                head_dim = backend.module.self_attention.head_dim
-                descr = TensorDescriptor(size=(2, batch_size, num_heads * head_dim * max_length), dtype=backend.dtype)
-                # ^-- flattened batch-first tensor of both keys and values; based on BLOOM layer_past layout
-                handles.append(await stack.enter_async_context(backend.memory_cache.allocate_cache(descr)))
-                total_size += descr.numel() * torch.finfo(descr.dtype).bits // 8
-            gib = 1024**3
-            if backend is not None:
-                cur_size = backend.memory_cache.current_size_bytes
-                max_size = backend.memory_cache.max_size_bytes
-                friendly_max_size = f"{max_size / gib:.2f}" if max_size != 2**64 - 1 else "inf"
-                cache_stats = f"used {cur_size / gib:.2f}/{friendly_max_size} GiB ({cur_size / max_size * 100:.1f}%)"
-            else:
-                cache_stats = f"cache stats n/a"
-            logger.info(f"rpc_inference.alloc(total_size={total_size / gib:.2f} GiB), {cache_stats}")
-            yield handles
+        """Allocate memory cache for all transformer blocks, return cache handle"""
+        n_blocks = len(backends)
+        backend = backends[0]
+        n_heads = backend.module.self_attention.num_heads
+        head_dim = backend.module.self_attention.head_dim
+        descr = TensorDescriptor(size=(n_blocks, 2, batch_size, n_heads * head_dim * max_length), dtype=backend.dtype)
+        alloc_size = descr.numel() * torch.finfo(descr.dtype).bits // 8
+
+        gib = 1024**3
+        cur_size = backend.memory_cache.current_size_bytes
+        max_size = backend.memory_cache.max_size_bytes
+        friendly_max_size = f"{max_size / gib:.2f}" if max_size != 2**64 - 1 else "inf"
+        logger.info(
+            f"rpc_inference.wait_for_alloc(size={alloc_size / gib:.2f} GiB), "
+            f"already used {cur_size / gib:.2f}/{friendly_max_size} GiB ({cur_size / max_size * 100:.1f}%)"
+        )
+
+        async with backend.memory_cache.allocate_cache(descr) as handle:
+            logger.info(f"rpc_inference.alloc(size={alloc_size / gib:.2f} GiB)")
+            yield handle
 
     def _log_request(self, method: str, uids: Sequence[ModuleUID], context: P2PContext) -> None:
         friendly_uids = [uid.split(".")[-1] for uid in uids if "." in uid]
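Since the descriptor is now allocated once for the whole span, its size grows linearly with the number of served blocks. A back-of-the-envelope check of the alloc_size arithmetic, using hypothetical BLOOM-176B-like numbers (8 blocks, 112 heads of dim 128, a 2048-token cache in bfloat16; none of these values come from the PR):

```python
import torch

# Hypothetical per-server sizes for a rough estimate only.
n_blocks, batch_size, n_heads, head_dim, max_length = 8, 1, 112, 128, 2048
dtype = torch.bfloat16

# Same arithmetic as in _allocate_cache: numel of the contiguous descriptor times bytes per element.
numel = n_blocks * 2 * batch_size * n_heads * head_dim * max_length
alloc_size = numel * torch.finfo(dtype).bits // 8

gib = 1024**3
print(f"contiguous KV cache: {alloc_size / gib:.2f} GiB")  # ~0.88 GiB for these toy numbers
```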
