@@ -321,15 +321,13 @@ class TransformerConnectionHandler(ConnectionHandler):
                 num_heads = backend.module.self_attention.num_heads
                 head_dim = backend.module.self_attention.head_dim
 
-                descr = TensorDescriptor(
-                    size=(2, batch_size, max_length, num_heads, head_dim), dtype=backend.dtype
-                )
+                descr = TensorDescriptor(size=(2, batch_size, max_length, num_heads, head_dim), dtype=backend.dtype)
                 # [key_or_value, batch_size, max_length, num_heads, head_dim]
 
                 handles.append(await stack.enter_async_context(backend.memory_cache.allocate_cache(descr)))
                 total_size += descr.numel() * torch.finfo(descr.dtype).bits // 8
 
-            gib = 1024 ** 3
+            gib = 1024**3
             if backend is not None:
                 cur_size = backend.memory_cache.current_size_bytes
                 max_size = backend.memory_cache.max_size_bytes
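For reference, the byte accounting in this hunk is easy to check in isolation: the descriptor holds one key tensor and one value tensor (the leading `2`), and its size in bytes is the element count times the bit width of the dtype, divided by 8. Below is a minimal, self-contained sketch of that arithmetic. The `TensorDescriptor` stand-in and the concrete batch/sequence/head numbers are illustrative assumptions for this sketch, not the real hivemind class or values from this PR:

```python
import torch
from dataclasses import dataclass

# Minimal stand-in for hivemind's TensorDescriptor; the real class carries
# more metadata. Only the fields used by the size arithmetic are modeled.
@dataclass
class TensorDescriptor:
    size: tuple
    dtype: torch.dtype

    def numel(self) -> int:
        n = 1
        for dim in self.size:
            n *= dim
        return n

# Hypothetical dimensions for one attention cache, mirroring the hunk above:
# shape is [key_or_value, batch_size, max_length, num_heads, head_dim].
batch_size, max_length, num_heads, head_dim = 1, 2048, 32, 128
descr = TensorDescriptor(size=(2, batch_size, max_length, num_heads, head_dim), dtype=torch.float16)

# Same byte accounting as the hunk: element count * bits per element // 8.
total_size = descr.numel() * torch.finfo(descr.dtype).bits // 8

gib = 1024**3
print(f"cache size: {total_size / gib:.3f} GiB")  # ~0.031 GiB for these numbers
```

With these assumed float16 dimensions, one block's cache comes to 2 * 1 * 2048 * 32 * 128 = 16,777,216 elements at 2 bytes each, i.e. 32 MiB; the loop in the hunk sums this over all backends before reporting usage against `memory_cache.max_size_bytes`.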