Let users specify sequence length instead of assuming 2048 (#52)

- Maximum length is now provided in `.inference_session(max_length=100)`
   - previously, we would always assume max length = 2048
- added a generic way to forward **kwargs to inference session
  - for compatibility with #47 
  - Note to @borzunov : it does *not* pass them arbitrarily, but instead checks for kwarg names at the bottom level
- run_server can be started with a custom max_length for inference
- renamed --cache_size_bytes to --attention_cache_bytes (to avoid collision with --cache_dir)
- --attn_cache_bytes can now support humane file sizes (e.g. 300MB instead of 314572800)
- made some server-side errors more human-readable to user (e.g. when max length is exceeded)

Co-authored-by: Aleksandr Borzunov <borzunov.alexander@gmail.com>
Co-authored-by: Alexander Borzunov <hxrussia@gmail.com>
prompt-inference
justheuristic 2 years ago committed by GitHub
parent 948877149c
commit d271b75dd4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -81,7 +81,7 @@ jobs:
python -m cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 0:12 \
--torch_dtype float32 --identity tests/test.id --host_maddrs /ip4/127.0.0.1/tcp/31337 \
--throughput 1 &> server1.log &
--throughput 1 --attn_cache_size 0.2GiB &> server1.log &
SERVER1_PID=$!
sleep 5 # wait for the first server to initialize DHT

@ -53,7 +53,7 @@ loss = (outputs * torch.randn_like(outputs)).norm()
loss.backward()
# test inference, one block
with layer3.inference_session() as sess:
with layer3.inference_session(max_length=10) as sess:
for i in range(10):
res = sess.step(torch.ones(1, 1, 4096))
```

@ -2,6 +2,7 @@ import configargparse
from hivemind.proto.runtime_pb2 import CompressionType
from hivemind.utils.limits import increase_file_limit
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
from humanfriendly import parse_size
from src.server.server import Server
@ -32,16 +33,19 @@ def main():
parser.add_argument('--min_batch_size', type=int, default=1,
help='Minimum required batch size for all expert operations')
parser.add_argument('--max_batch_size', type=int, default=16384,
help='The total number of examples in the same batch will not exceed this value')
help='The total number of tokens in the same batch will not exceed this value')
parser.add_argument('--inference_max_length', type=int, default=16384,
help='Maximum total sequence length permitted per inference, defaults to 16384 tokens')
parser.add_argument('--cache_dir', type=str, default=None,
help='Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.')
parser.add_argument('--cache_size_bytes', type=int, default=None,
help='The size of memory cache for storing past attention keys/values between inference steps')
parser.add_argument('--device', type=str, default=None, required=False,
help='all experts will use this device in torch notation; default: cuda if available else cpu')
parser.add_argument("--torch_dtype", type=str, default="auto",
help="Use this dtype to store block weights and do computations. "
"By default, respect the dtypes in the pre-trained state dict.")
parser.add_argument('--attn_cache_size', type=str, default=None,
help='The size of GPU memory allocated for storing past attention keys/values between inference'
' steps; examples: 500MB or 1.2GB or 1073741824 (bytes); be warned: 1KB != 1KiB')
parser.add_argument('--revision', type=str, default='main',
help="The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models"
"and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.")
@ -81,10 +85,17 @@ def main():
compression_type = args.pop("compression")
compression = getattr(CompressionType, compression_type)
attn_cache_size = args.pop("attn_cache_size")
if attn_cache_size is not None:
attn_cache_size = parse_size(attn_cache_size)
assert isinstance(
attn_cache_size, (int, type(None))
), "unrecognized value for attention_cache_bytes, examples: 1.5GB or 1500MB or 1572864000 (bytes)"
use_auth_token = args.pop("use_auth_token")
args["use_auth_token"] = True if use_auth_token in ("True", "true", "") else use_auth_token
server = Server.create(**args, start=True, compression=compression)
server = Server.create(**args, start=True, compression=compression, attn_cache_size=attn_cache_size)
try:
server.join()

@ -1,5 +1,6 @@
torch==1.12.0
accelerate==0.10.0
huggingface-hub==0.7.0
humanfriendly
https://github.com/learning-at-home/hivemind/archive/20b3b3d5f225ed525515a5383a008a8f9fad8173.zip
https://github.com/huggingface/transformers/archive/6589e510fa4e6c442059de2fab84752535de9b23.zip

@ -7,6 +7,7 @@ from typing import AsyncIterator, List, Optional
import torch
from hivemind import (
P2P,
MSGPackSerializer,
anext,
deserialize_torch_tensor,
get_logger,
@ -33,23 +34,32 @@ class RemoteTransformerBlockInferenceSession:
:note: this inference session is *not* fault-tolerant out of the box
"""
def __init__(self, uid: ModuleUID, rpc_info: RPCInfo, inputs_queue: asyncio.Queue, outputs_aiter: AsyncIterator):
def __init__(
self,
uid: ModuleUID,
rpc_info: RPCInfo,
inputs_queue: asyncio.Queue,
outputs_aiter: AsyncIterator,
*,
max_length: int,
):
self.uid, self.rpc_info = uid, rpc_info
# warning: this code manages async objects that are only usable inside RemoteExpertWorker's background thread;
# using them in any other EventLoop may cause side-effects including, headaches, diarrhea, and loss of sleep
self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter
self._serialized_metadata = MSGPackSerializer.dumps(dict(max_length=max_length))
self.stepped = False
self.closed = False
@classmethod
async def _create(
cls, stub: StubBase, uid: ModuleUID, rpc_info: RPCInfo, timeout: Optional[float] = None
cls, stub: StubBase, uid: ModuleUID, rpc_info: RPCInfo, timeout: Optional[float] = None, **metadata
) -> RemoteTransformerBlockInferenceSession:
"""Create a new session for a given remote module. This code is meant to be run inside RemoteExpertWorker"""
inputs_queue = asyncio.Queue()
outputs_stream = await stub.rpc_inference(cls._read_inputs_from_queue(inputs_queue, timeout), timeout=timeout)
return cls(uid, rpc_info, inputs_queue, outputs_stream)
return cls(uid, rpc_info, inputs_queue, outputs_stream, **metadata)
@staticmethod
async def _read_inputs_from_queue(queue: asyncio.Queue, timeout: Optional[float]) -> AsyncIterator:
@ -73,6 +83,7 @@ class RemoteTransformerBlockInferenceSession:
serialize_torch_tensor(tensor.to(proto.dtype), proto.compression)
for tensor, proto in zip(inputs, nested_flatten(self.rpc_info["forward_schema"]))
],
metadata=self._serialized_metadata if not self.stepped else None,
)
)
)
@ -121,13 +132,14 @@ class RemoteSequentialInferenceSession:
An interface to a multi-step *inference* session for a sequence of remote transformer blocks
"""
def __init__(self, sequence_manager: RemoteSequenceManager, p2p: P2P, timeout: Optional[float] = None):
def __init__(self, sequence_manager: RemoteSequenceManager, p2p: P2P, timeout: Optional[float] = None, **metadata):
self.sequence_manager = sequence_manager
self.p2p = p2p
self.closed = False
self.chosen_spans: List[RemoteSpanInfo] = []
self.stack = contextlib.ExitStack()
self.inference_sessions: List[RemoteTransformerBlockInferenceSession] = []
self.metadata = metadata
self.timeout = timeout
def __enter__(self):
@ -141,7 +153,7 @@ class RemoteSequentialInferenceSession:
span_uids: str = CHAIN_DELIMITER.join(self.sequence_manager.block_uids[chosen_span.start : chosen_span.end])
inference_session = RemoteExpertWorker.run_coroutine(
RemoteTransformerBlockInferenceSession._create(
stub, span_uids, rpc_info=self.sequence_manager.rpc_info, timeout=self.timeout
stub, span_uids, rpc_info=self.sequence_manager.rpc_info, timeout=self.timeout, **self.metadata
)
)
self.inference_sessions.append(inference_session)

@ -33,12 +33,8 @@ class RemoteTransformerBlock(RemoteExpert):
assert v is None or v is False, f"Extra keyword arguments are not yet supported (got {k} = {v})"
return super().forward(inputs)
def inference_session(self) -> RemoteTransformerBlockInferenceSession:
def inference_session(self, **kwargs) -> RemoteTransformerBlockInferenceSession:
"""Initialize a new inference session with the specified remote server"""
return RemoteExpertWorker.run_coroutine(
RemoteTransformerBlockInferenceSession._create(self.stub, self.uid, self.info)
RemoteTransformerBlockInferenceSession._create(self.stub, self.uid, self.info, **kwargs)
)
def begin_inference_session(self):
logger.warning("beging_inference_session was renamed to just inference_session")
return self.inference_session()

@ -60,14 +60,20 @@ class RemoteGenerationMixin:
assert (
model_kwargs.get("stopping_criteria", None) is None
), "For RemoteGenerationMixin models use BloomConstraints instead of stopping_criteria"
if inputs is not None:
assert isinstance(inputs, torch.Tensor) and inputs.ndim == 2, "inputs must be a 2d tensor [batch, length]"
prefix_length = 0 if inputs is None else inputs.size(1)
bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
assert (max_length is None) != (max_new_tokens is None), "please set max_length or max_new_tokens (not both)"
if max_length is not None and max_new_tokens is None:
max_new_tokens = max_length - inputs.size(1)
max_new_tokens = max_length - prefix_length
assert max_new_tokens > 0, f"Provided max_length is less than prefix size: {max_length} < {inputs.size(1)}"
elif max_length is None and max_new_tokens is not None:
max_length = prefix_length + max_new_tokens
if inputs is None:
assert bos_token_id is not None, "You have to provide a bos_token_id if you do not provide inputs"
@ -87,7 +93,7 @@ class RemoteGenerationMixin:
provided_constraints=provided_constraints,
)
with self.transformer.h.inference_session() as sess:
with self.transformer.h.inference_session(max_length=max_length) as sess:
outputs = []
if torch.any(inputs == pad_token_id): # TODO: move to prepare_inputs
outputs += [inputs[:, : inputs.size(1) - (inputs == pad_token_id).sum(-1).max()]]

@ -79,9 +79,9 @@ class RemoteSequential(nn.Module):
def __len__(self):
return len(self.sequence_manager)
def inference_session(self) -> RemoteSequentialInferenceSession:
def inference_session(self, **kwargs) -> RemoteSequentialInferenceSession:
self.sequence_manager.update_()
return RemoteSequentialInferenceSession(self.sequence_manager, self.p2p)
return RemoteSequentialInferenceSession(self.sequence_manager, self.p2p, **kwargs)
def extra_repr(self) -> str:
return f"modules={self.sequence_manager.block_uids[0]}..{self.sequence_manager.block_uids[-1]}"

@ -14,8 +14,6 @@ from src.server.cache import MemoryCache
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)
MAX_LENGTH = 2048
class InferenceTaskPool(TaskPool):
def __init__(self, *args, **kwargs):

@ -19,7 +19,7 @@ from hivemind.utils.asyncio import anext
from hivemind.utils.streaming import split_for_streaming
from src.data_structures import CHAIN_DELIMITER, ModuleUID
from src.server.backend import MAX_LENGTH, TransformerBackend
from src.server.backend import TransformerBackend
from src.utils.misc import DUMMY, is_dummy
@ -28,10 +28,11 @@ class TransformerConnectionHandler(ConnectionHandler):
module_backends: Dict[ModuleUID, TransformerBackend]
def __init__(self, dht: DHT, module_backends: Dict[str, TransformerBackend]):
def __init__(self, dht: DHT, module_backends: Dict[str, TransformerBackend], inference_max_length: int):
super().__init__(dht, module_backends)
for module_backend in self.module_backends.values():
assert isinstance(module_backend, TransformerBackend)
self.inference_max_length = inference_max_length
async def rpc_inference(
self,
@ -43,7 +44,15 @@ class TransformerConnectionHandler(ConnectionHandler):
print("OPENED RPC_INFERENCE")
request = await anext(requests)
requested_uids = self._check_uids(request.uid)
metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
max_length = metadata.get("max_length")
if not requested_uids:
raise ValueError("User must specify at least one block for inference, but got none")
assert isinstance(max_length, int), f"rpc_inference metadata must contain int max_length, got {max_length}"
if not 0 <= max_length <= self.inference_max_length:
raise ValueError(f"Cannot allocate KV cache for {max_length} tokens, max = {self.inference_max_length}")
batch_size = request.tensors[0].size[0] if request.tensors else 1
@ -52,10 +61,17 @@ class TransformerConnectionHandler(ConnectionHandler):
) # [cache_handle, prefix_length]
prefix_length = 0
async with self._allocate_caches(requested_backends, batch_size) as cache_handles:
async with self._allocate_caches(requested_backends, batch_size, max_length) as cache_handles:
assert len(cache_handles) == len(requested_backends)
while request.tensors: # iterate while user is willing to supply tensors
hidden_states = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
length_increment = hidden_states[0].shape[1] # how many tokens are added this step (in each seq)
if prefix_length + length_increment > max_length:
raise ValueError(
f"Maximum length exceeded: prefix {prefix_length} + current {length_increment}"
f" exceeds pre-allocated maximum {max_length}"
)
# Cast inputs to backend dtype
hidden_states = [tensor.to(requested_backends[0].dtype) for tensor in hidden_states]
@ -113,7 +129,7 @@ class TransformerConnectionHandler(ConnectionHandler):
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
hidden_states = await _rpc_forward(*flat_inputs, requested_backends=requested_backends)
assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3
assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3, "hidden_states must be a 3d tensor"
# Serialize the overall output
serialized_output = [
@ -193,7 +209,9 @@ class TransformerConnectionHandler(ConnectionHandler):
return tuple(uids)
@contextlib.asynccontextmanager
async def _allocate_caches(self, backends: Sequence[TransformerBackend], batch_size: int) -> Sequence[int]:
async def _allocate_caches(
self, backends: Sequence[TransformerBackend], batch_size: int, max_length: int
) -> Sequence[int]:
"""Allocate memory caches for each transformer block, return cache handles"""
async with contextlib.AsyncExitStack() as stack:
handles = []
@ -202,7 +220,7 @@ class TransformerConnectionHandler(ConnectionHandler):
head_dim = backend.module.self_attention.head_dim
cache_descriptor = TensorDescriptor(
size=(2, batch_size, MAX_LENGTH, num_heads, head_dim), dtype=backend.dtype
size=(2, batch_size, max_length, num_heads, head_dim), dtype=backend.dtype
)
# [key_or_value, batch_size, max_length, num_heads, head_dim]

@ -36,6 +36,7 @@ class Server(threading.Thread):
dht: DHT,
module_backends: Dict[str, TransformerBackend],
*,
inference_max_length: int,
num_connection_handlers: int = 8,
throughput: float,
update_period: float = 30,
@ -47,7 +48,8 @@ class Server(threading.Thread):
self.dht, self.module_backends = dht, module_backends
self.throughput, self.update_period, self.expiration = throughput, update_period, expiration
self.conn_handlers = [
TransformerConnectionHandler(dht, self.module_backends) for _ in range(num_connection_handlers)
TransformerConnectionHandler(dht, self.module_backends, inference_max_length)
for _ in range(num_connection_handlers)
]
self.runtime = Runtime(self.module_backends, **kwargs)
self.dht_handler_thread = ModuleAnnouncerThread(
@ -104,10 +106,11 @@ class Server(threading.Thread):
num_handlers: int = 8,
min_batch_size: int = 1,
max_batch_size: int = 4096,
inference_max_length: int = 4096,
torch_dtype: str = "auto",
revision: str = "main",
cache_dir: Optional[str] = None,
cache_size_bytes: Optional[int] = None,
attn_cache_size: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
initial_peers: Sequence[str] = (),
compression=CompressionType.NONE,
@ -141,7 +144,7 @@ class Server(threading.Thread):
logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")
device = device or ("cuda" if torch.cuda.is_available() else "cpu")
memory_cache = MemoryCache(device, cache_size_bytes)
memory_cache = MemoryCache(device, attn_cache_size)
assert isinstance(throughput, float) or throughput in ["auto", "eval"]
if throughput in ["auto", "eval"]:
@ -228,6 +231,7 @@ class Server(threading.Thread):
blocks,
throughput=throughput,
num_connection_handlers=num_handlers,
inference_max_length=inference_max_length,
device=device,
stats_report_interval=stats_report_interval,
update_period=update_period,

@ -4,6 +4,7 @@ import hivemind
import pytest
import torch
import transformers
from hivemind import P2PHandlerError
from test_utils import *
from src.bloom.from_pretrained import load_pretrained_block
@ -27,9 +28,15 @@ def test_remote_block_exact_match(atol_forward=1e-5, atol_inference=1e-3):
(outputs_forward,) = remote_block(inputs)
outputs_inference = []
with remote_block.inference_session() as sess:
with remote_block.inference_session(max_length=inputs.shape[1]) as sess:
for i in range(inputs.shape[1]):
outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))
# test that max length is respected
with pytest.raises(P2PHandlerError) as exc_info:
sess.step(inputs[:, -1:, :])
assert "Maximum length exceeded" in repr(exc_info.value)
outputs_inference = torch.cat(outputs_inference, dim=1)
ref_block = load_pretrained_block(MODEL_NAME, block_index, torch_dtype=torch.float32)

@ -63,7 +63,7 @@ def test_chained_inference_exact_match(atol_inference=1e-4):
inputs = torch.randn(1, 8, config.hidden_size)
outputs_inference = []
with remote_block.inference_session() as sess:
with remote_block.inference_session(max_length=inputs.shape[1]) as sess:
for i in range(inputs.shape[1]):
outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))
outputs_inference = torch.cat(outputs_inference, dim=1)

@ -31,7 +31,7 @@ def test_full_model_exact_match(atol_forward=1e-3, atol_inference=1e-3):
embs = model.transformer.word_embeddings(test_inputs)
embs = model.transformer.word_embeddings_layernorm(embs)
recurrent_outputs = []
with model.transformer.h.inference_session() as sess:
with model.transformer.h.inference_session(max_length=embs.shape[1]) as sess:
for t in range(embs.shape[1]):
recurrent_outputs.append(sess.step(embs[:, t : t + 1, :]))
recurrent_outputs = torch.cat(recurrent_outputs, dim=1)

Loading…
Cancel
Save