pull/96/head
Aleksandr Borzunov 2 years ago
parent f0879f5d07
commit bf0be9f031

@@ -321,15 +321,13 @@ class TransformerConnectionHandler(ConnectionHandler):
                 num_heads = backend.module.self_attention.num_heads
                 head_dim = backend.module.self_attention.head_dim
-                descr = TensorDescriptor(
-                    size=(2, batch_size, max_length, num_heads, head_dim), dtype=backend.dtype
-                )
+                descr = TensorDescriptor(size=(2, batch_size, max_length, num_heads, head_dim), dtype=backend.dtype)
                 # [key_or_value, batch_size, max_length, num_heads, head_dim]
                 handles.append(await stack.enter_async_context(backend.memory_cache.allocate_cache(descr)))
                 total_size += descr.numel() * torch.finfo(descr.dtype).bits // 8
-            gib = 1024 ** 3
+            gib = 1024**3
             if backend is not None:
                 cur_size = backend.memory_cache.current_size_bytes
                 max_size = backend.memory_cache.max_size_bytes
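For reference, the byte-count arithmetic in this hunk can be checked standalone. Below is a minimal sketch of the same formula; the shape values (batch_size, max_length, num_heads, head_dim) and the dtype are hypothetical placeholders, not taken from the commit, and in the real handler they come from backend.module.self_attention and the client's request.

    import torch
    from math import prod

    # Hypothetical per-block dimensions (placeholders for illustration only).
    batch_size, max_length, num_heads, head_dim = 1, 2048, 32, 128
    dtype = torch.float16  # stands in for backend.dtype

    # Leading 2 = [key_or_value], matching the descriptor in the diff:
    shape = (2, batch_size, max_length, num_heads, head_dim)

    # Same formula as the diff: element count times bits per element, in bytes.
    # Note torch.finfo covers floating dtypes only; integer dtypes would need torch.iinfo.
    size_bytes = prod(shape) * torch.finfo(dtype).bits // 8

    gib = 1024**3  # the commit's normalized spelling of 1024 ** 3
    print(f"KV cache for one transformer block: {size_bytes / gib:.3f} GiB")

With these placeholder numbers, a single block's key/value cache comes to 32 MiB (about 0.031 GiB); the handler accumulates this figure over all served blocks into total_size and compares it against the memory cache's current_size_bytes and max_size_bytes.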
