diff --git a/.github/workflows/run-tests.yaml b/.github/workflows/run-tests.yaml
index 60cb02f..3cb25ad 100644
--- a/.github/workflows/run-tests.yaml
+++ b/.github/workflows/run-tests.yaml
@@ -81,7 +81,7 @@ jobs:
 
           python -m cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 0:12 \
             --torch_dtype float32 --identity tests/test.id --host_maddrs /ip4/127.0.0.1/tcp/31337 \
-            --throughput 1 &> server1.log &
+            --throughput 1 --attn_cache_size 0.2GiB &> server1.log &
           SERVER1_PID=$!
 
           sleep 5  # wait for the first server to initialize DHT
diff --git a/README.md b/README.md
index aca399f..b59606d 100644
--- a/README.md
+++ b/README.md
@@ -53,7 +53,7 @@ loss = (outputs * torch.randn_like(outputs)).norm()
 loss.backward()
 
 # test inference, one block
-with layer3.inference_session() as sess:
+with layer3.inference_session(max_length=10) as sess:
     for i in range(10):
         res = sess.step(torch.ones(1, 1, 4096))
 ```
diff --git a/cli/run_server.py b/cli/run_server.py
index 1cdae73..03055f7 100644
--- a/cli/run_server.py
+++ b/cli/run_server.py
@@ -2,6 +2,7 @@ import configargparse
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils.limits import increase_file_limit
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
+from humanfriendly import parse_size
 
 from src.server.server import Server
 
@@ -32,16 +33,19 @@ def main():
     parser.add_argument('--min_batch_size', type=int, default=1,
                         help='Minimum required batch size for all expert operations')
     parser.add_argument('--max_batch_size', type=int, default=16384,
-                        help='The total number of examples in the same batch will not exceed this value')
+                        help='The total number of tokens in the same batch will not exceed this value')
+    parser.add_argument('--inference_max_length', type=int, default=16384,
+                        help='Maximum total sequence length permitted per inference, defaults to 16384 tokens')
     parser.add_argument('--cache_dir', type=str, default=None,
                         help='Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.')
-    parser.add_argument('--cache_size_bytes', type=int, default=None,
-                        help='The size of memory cache for storing past attention keys/values between inference steps')
     parser.add_argument('--device', type=str, default=None, required=False,
                         help='all experts will use this device in torch notation; default: cuda if available else cpu')
     parser.add_argument("--torch_dtype", type=str, default="auto",
                         help="Use this dtype to store block weights and do computations. "
                              "By default, respect the dtypes in the pre-trained state dict.")
+    parser.add_argument('--attn_cache_size', type=str, default=None,
+                        help='The size of GPU memory allocated for storing past attention keys/values between inference'
+                             ' steps; examples: 500MB or 1.2GB or 1073741824 (bytes); be warned: 1KB != 1KiB')
     parser.add_argument('--revision', type=str, default='main',
                         help="The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models"
                              "and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.")
@@ -81,10 +85,17 @@ def main():
     compression_type = args.pop("compression")
     compression = getattr(CompressionType, compression_type)
 
+    attn_cache_size = args.pop("attn_cache_size")
+    if attn_cache_size is not None:
+        attn_cache_size = parse_size(attn_cache_size)
+    assert isinstance(
+        attn_cache_size, (int, type(None))
+    ), "unrecognized value for --attn_cache_size, examples: 1.5GB or 1500MB or 1572864000 (bytes)"
+
     use_auth_token = args.pop("use_auth_token")
     args["use_auth_token"] = True if use_auth_token in ("True", "true", "") else use_auth_token
 
-    server = Server.create(**args, start=True, compression=compression)
+    server = Server.create(**args, start=True, compression=compression, attn_cache_size=attn_cache_size)
 
     try:
         server.join()
diff --git a/requirements.txt b/requirements.txt
index feccf05..afc0290 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,6 @@
 torch==1.12.0
 accelerate==0.10.0
 huggingface-hub==0.7.0
+humanfriendly
 https://github.com/learning-at-home/hivemind/archive/20b3b3d5f225ed525515a5383a008a8f9fad8173.zip
 https://github.com/huggingface/transformers/archive/6589e510fa4e6c442059de2fab84752535de9b23.zip
\ No newline at end of file
diff --git a/src/client/inference_session.py b/src/client/inference_session.py
index 94f6ffa..24852df 100644
--- a/src/client/inference_session.py
+++ b/src/client/inference_session.py
@@ -7,6 +7,7 @@ from typing import AsyncIterator, List, Optional
 import torch
 from hivemind import (
     P2P,
+    MSGPackSerializer,
     anext,
     deserialize_torch_tensor,
     get_logger,
@@ -33,23 +34,32 @@ class RemoteTransformerBlockInferenceSession:
     :note: this inference session is *not* fault-tolerant out of the box
     """
 
-    def __init__(self, uid: ModuleUID, rpc_info: RPCInfo, inputs_queue: asyncio.Queue, outputs_aiter: AsyncIterator):
+    def __init__(
+        self,
+        uid: ModuleUID,
+        rpc_info: RPCInfo,
+        inputs_queue: asyncio.Queue,
+        outputs_aiter: AsyncIterator,
+        *,
+        max_length: int,
+    ):
         self.uid, self.rpc_info = uid, rpc_info
         # warning: this code manages async objects that are only usable inside RemoteExpertWorker's background thread;
         # using them in any other EventLoop may cause side-effects including, headaches, diarrhea, and loss of sleep
         self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
         self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter
+        self._serialized_metadata = MSGPackSerializer.dumps(dict(max_length=max_length))
         self.stepped = False
         self.closed = False
 
     @classmethod
     async def _create(
-        cls, stub: StubBase, uid: ModuleUID, rpc_info: RPCInfo, timeout: Optional[float] = None
+        cls, stub: StubBase, uid: ModuleUID, rpc_info: RPCInfo, timeout: Optional[float] = None, **metadata
     ) -> RemoteTransformerBlockInferenceSession:
         """Create a new session for a given remote module. This code is meant to be run inside RemoteExpertWorker"""
         inputs_queue = asyncio.Queue()
         outputs_stream = await stub.rpc_inference(cls._read_inputs_from_queue(inputs_queue, timeout), timeout=timeout)
-        return cls(uid, rpc_info, inputs_queue, outputs_stream)
+        return cls(uid, rpc_info, inputs_queue, outputs_stream, **metadata)
 
     @staticmethod
     async def _read_inputs_from_queue(queue: asyncio.Queue, timeout: Optional[float]) -> AsyncIterator:
@@ -73,6 +83,7 @@ class RemoteTransformerBlockInferenceSession:
                         serialize_torch_tensor(tensor.to(proto.dtype), proto.compression)
                         for tensor, proto in zip(inputs, nested_flatten(self.rpc_info["forward_schema"]))
                     ],
+                    metadata=self._serialized_metadata if not self.stepped else None,
                 )
             )
         )
@@ -121,13 +132,14 @@ class RemoteSequentialInferenceSession:
     An interface to a multi-step *inference* session for a sequence of remote transformer blocks
     """
 
-    def __init__(self, sequence_manager: RemoteSequenceManager, p2p: P2P, timeout: Optional[float] = None):
+    def __init__(self, sequence_manager: RemoteSequenceManager, p2p: P2P, timeout: Optional[float] = None, **metadata):
         self.sequence_manager = sequence_manager
         self.p2p = p2p
         self.closed = False
         self.chosen_spans: List[RemoteSpanInfo] = []
         self.stack = contextlib.ExitStack()
         self.inference_sessions: List[RemoteTransformerBlockInferenceSession] = []
+        self.metadata = metadata
         self.timeout = timeout
 
     def __enter__(self):
@@ -141,7 +153,7 @@ class RemoteSequentialInferenceSession:
             span_uids: str = CHAIN_DELIMITER.join(self.sequence_manager.block_uids[chosen_span.start : chosen_span.end])
             inference_session = RemoteExpertWorker.run_coroutine(
                 RemoteTransformerBlockInferenceSession._create(
-                    stub, span_uids, rpc_info=self.sequence_manager.rpc_info, timeout=self.timeout
+                    stub, span_uids, rpc_info=self.sequence_manager.rpc_info, timeout=self.timeout, **self.metadata
                 )
             )
             self.inference_sessions.append(inference_session)
diff --git a/src/client/remote_block.py b/src/client/remote_block.py
index 68cd004..7d0f920 100644
--- a/src/client/remote_block.py
+++ b/src/client/remote_block.py
@@ -33,12 +33,8 @@ class RemoteTransformerBlock(RemoteExpert):
             assert v is None or v is False, f"Extra keyword arguments are not yet supported (got {k} = {v})"
         return super().forward(inputs)
 
-    def inference_session(self) -> RemoteTransformerBlockInferenceSession:
+    def inference_session(self, **kwargs) -> RemoteTransformerBlockInferenceSession:
         """Initialize a new inference session with the specified remote server"""
         return RemoteExpertWorker.run_coroutine(
-            RemoteTransformerBlockInferenceSession._create(self.stub, self.uid, self.info)
+            RemoteTransformerBlockInferenceSession._create(self.stub, self.uid, self.info, **kwargs)
         )
-
-    def begin_inference_session(self):
-        logger.warning("beging_inference_session was renamed to just inference_session")
-        return self.inference_session()
diff --git a/src/client/remote_generation.py b/src/client/remote_generation.py
index d20462a..1b1b5cd 100644
--- a/src/client/remote_generation.py
+++ b/src/client/remote_generation.py
@@ -60,14 +60,20 @@ class RemoteGenerationMixin:
         assert (
             model_kwargs.get("stopping_criteria", None) is None
         ), "For RemoteGenerationMixin models use BloomConstraints instead of stopping_criteria"
+        if inputs is not None:
+            assert isinstance(inputs, torch.Tensor) and inputs.ndim == 2, "inputs must be a 2d tensor [batch, length]"
+        prefix_length = 0 if inputs is None else inputs.size(1)
 
         bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
         pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
         eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
 
+        assert (max_length is None) != (max_new_tokens is None), "please set max_length or max_new_tokens (not both)"
         if max_length is not None and max_new_tokens is None:
-            max_new_tokens = max_length - inputs.size(1)
+            max_new_tokens = max_length - prefix_length
             assert max_new_tokens > 0, f"Provided max_length is less than prefix size: {max_length} < {inputs.size(1)}"
+        elif max_length is None and max_new_tokens is not None:
+            max_length = prefix_length + max_new_tokens
 
         if inputs is None:
             assert bos_token_id is not None, "You have to provide a bos_token_id if you do not provide inputs"
@@ -87,7 +93,7 @@ class RemoteGenerationMixin:
             provided_constraints=provided_constraints,
         )
 
-        with self.transformer.h.inference_session() as sess:
+        with self.transformer.h.inference_session(max_length=max_length) as sess:
             outputs = []
             if torch.any(inputs == pad_token_id):  # TODO: move to prepare_inputs
                 outputs += [inputs[:, : inputs.size(1) - (inputs == pad_token_id).sum(-1).max()]]
diff --git a/src/client/remote_sequential.py b/src/client/remote_sequential.py
index 86acfe1..86cca85 100644
--- a/src/client/remote_sequential.py
+++ b/src/client/remote_sequential.py
@@ -79,9 +79,9 @@ class RemoteSequential(nn.Module):
     def __len__(self):
         return len(self.sequence_manager)
 
-    def inference_session(self) -> RemoteSequentialInferenceSession:
+    def inference_session(self, **kwargs) -> RemoteSequentialInferenceSession:
         self.sequence_manager.update_()
-        return RemoteSequentialInferenceSession(self.sequence_manager, self.p2p)
+        return RemoteSequentialInferenceSession(self.sequence_manager, self.p2p, **kwargs)
 
     def extra_repr(self) -> str:
         return f"modules={self.sequence_manager.block_uids[0]}..{self.sequence_manager.block_uids[-1]}"
diff --git a/src/server/backend.py b/src/server/backend.py
index 6a883f6..9929770 100644
--- a/src/server/backend.py
+++ b/src/server/backend.py
@@ -14,8 +14,6 @@ from src.server.cache import MemoryCache
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
-MAX_LENGTH = 2048
-
 
 class InferenceTaskPool(TaskPool):
     def __init__(self, *args, **kwargs):
diff --git a/src/server/handler.py b/src/server/handler.py
index c50a8aa..27ed562 100644
--- a/src/server/handler.py
+++ b/src/server/handler.py
@@ -19,7 +19,7 @@ from hivemind.utils.asyncio import anext
 from hivemind.utils.streaming import split_for_streaming
 
 from src.data_structures import CHAIN_DELIMITER, ModuleUID
-from src.server.backend import MAX_LENGTH, TransformerBackend
+from src.server.backend import TransformerBackend
 from src.utils.misc import DUMMY, is_dummy
 
 
@@ -28,10 +28,11 @@ class TransformerConnectionHandler(ConnectionHandler):
 
     module_backends: Dict[ModuleUID, TransformerBackend]
 
-    def __init__(self, dht: DHT, module_backends: Dict[str, TransformerBackend]):
+    def __init__(self, dht: DHT, module_backends: Dict[str, TransformerBackend], inference_max_length: int):
         super().__init__(dht, module_backends)
         for module_backend in self.module_backends.values():
             assert isinstance(module_backend, TransformerBackend)
+        self.inference_max_length = inference_max_length
 
     async def rpc_inference(
         self,
@@ -43,7 +44,15 @@ class TransformerConnectionHandler(ConnectionHandler):
             print("OPENED RPC_INFERENCE")
             request = await anext(requests)
             requested_uids = self._check_uids(request.uid)
+            metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
             requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
+            max_length = metadata.get("max_length")
+
+            if not requested_uids:
+                raise ValueError("User must specify at least one block for inference, but got none")
+            assert isinstance(max_length, int), f"rpc_inference metadata must contain int max_length, got {max_length}"
+            if not 0 <= max_length <= self.inference_max_length:
+                raise ValueError(f"Cannot allocate KV cache for {max_length} tokens, max = {self.inference_max_length}")
 
             batch_size = request.tensors[0].size[0] if request.tensors else 1
 
@@ -52,10 +61,17 @@ class TransformerConnectionHandler(ConnectionHandler):
             )  # [cache_handle, prefix_length]
             prefix_length = 0
 
-            async with self._allocate_caches(requested_backends, batch_size) as cache_handles:
+            async with self._allocate_caches(requested_backends, batch_size, max_length) as cache_handles:
                 assert len(cache_handles) == len(requested_backends)
                 while request.tensors:  # iterate while user is willing to supply tensors
                     hidden_states = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
+                    length_increment = hidden_states[0].shape[1]  # how many tokens are added this step (in each seq)
+
+                    if prefix_length + length_increment > max_length:
+                        raise ValueError(
+                            f"Maximum length exceeded: prefix {prefix_length} + current {length_increment}"
+                            f" exceeds pre-allocated maximum {max_length}"
+                        )
 
                     # Cast inputs to backend dtype
                     hidden_states = [tensor.to(requested_backends[0].dtype) for tensor in hidden_states]
@@ -113,7 +129,7 @@ class TransformerConnectionHandler(ConnectionHandler):
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
 
         hidden_states = await _rpc_forward(*flat_inputs, requested_backends=requested_backends)
-        assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3
+        assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3, "hidden_states must be a 3d tensor"
 
         # Serialize the overall output
         serialized_output = [
@@ -193,7 +209,9 @@ class TransformerConnectionHandler(ConnectionHandler):
         return tuple(uids)
 
     @contextlib.asynccontextmanager
-    async def _allocate_caches(self, backends: Sequence[TransformerBackend], batch_size: int) -> Sequence[int]:
+    async def _allocate_caches(
+        self, backends: Sequence[TransformerBackend], batch_size: int, max_length: int
+    ) -> Sequence[int]:
         """Allocate memory caches for each transformer block, return cache handles"""
         async with contextlib.AsyncExitStack() as stack:
             handles = []
@@ -202,7 +220,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                 head_dim = backend.module.self_attention.head_dim
 
                 cache_descriptor = TensorDescriptor(
-                    size=(2, batch_size, MAX_LENGTH, num_heads, head_dim), dtype=backend.dtype
+                    size=(2, batch_size, max_length, num_heads, head_dim), dtype=backend.dtype
                 )  # [key_or_value, batch_size, max_length, num_heads, head_dim]
 
diff --git a/src/server/server.py b/src/server/server.py
index 057daa5..5d92bd9 100644
--- a/src/server/server.py
+++ b/src/server/server.py
@@ -36,6 +36,7 @@ class Server(threading.Thread):
         dht: DHT,
         module_backends: Dict[str, TransformerBackend],
         *,
+        inference_max_length: int,
         num_connection_handlers: int = 8,
         throughput: float,
         update_period: float = 30,
@@ -47,7 +48,8 @@ class Server(threading.Thread):
         self.dht, self.module_backends = dht, module_backends
         self.throughput, self.update_period, self.expiration = throughput, update_period, expiration
         self.conn_handlers = [
-            TransformerConnectionHandler(dht, self.module_backends) for _ in range(num_connection_handlers)
+            TransformerConnectionHandler(dht, self.module_backends, inference_max_length)
+            for _ in range(num_connection_handlers)
         ]
         self.runtime = Runtime(self.module_backends, **kwargs)
         self.dht_handler_thread = ModuleAnnouncerThread(
@@ -104,10 +106,11 @@ class Server(threading.Thread):
         num_handlers: int = 8,
         min_batch_size: int = 1,
         max_batch_size: int = 4096,
+        inference_max_length: int = 4096,
         torch_dtype: str = "auto",
         revision: str = "main",
         cache_dir: Optional[str] = None,
-        cache_size_bytes: Optional[int] = None,
+        attn_cache_size: Optional[int] = None,
         device: Optional[Union[str, torch.device]] = None,
         initial_peers: Sequence[str] = (),
         compression=CompressionType.NONE,
@@ -141,7 +144,7 @@ class Server(threading.Thread):
         logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")
 
         device = device or ("cuda" if torch.cuda.is_available() else "cpu")
-        memory_cache = MemoryCache(device, cache_size_bytes)
+        memory_cache = MemoryCache(device, attn_cache_size)
 
         assert isinstance(throughput, float) or throughput in ["auto", "eval"]
         if throughput in ["auto", "eval"]:
@@ -228,6 +231,7 @@ class Server(threading.Thread):
             blocks,
             throughput=throughput,
             num_connection_handlers=num_handlers,
+            inference_max_length=inference_max_length,
             device=device,
             stats_report_interval=stats_report_interval,
             update_period=update_period,
diff --git a/tests/test_block_exact_match.py b/tests/test_block_exact_match.py
index caac346..4761aea 100644
--- a/tests/test_block_exact_match.py
+++ b/tests/test_block_exact_match.py
@@ -4,6 +4,7 @@ import hivemind
 import pytest
 import torch
 import transformers
+from hivemind import P2PHandlerError
 from test_utils import *
 
 from src.bloom.from_pretrained import load_pretrained_block
@@ -27,9 +28,15 @@ def test_remote_block_exact_match(atol_forward=1e-5, atol_inference=1e-3):
         (outputs_forward,) = remote_block(inputs)
 
         outputs_inference = []
-        with remote_block.inference_session() as sess:
+        with remote_block.inference_session(max_length=inputs.shape[1]) as sess:
             for i in range(inputs.shape[1]):
                 outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))
+
+            # test that max length is respected
+            with pytest.raises(P2PHandlerError) as exc_info:
+                sess.step(inputs[:, -1:, :])
+            assert "Maximum length exceeded" in repr(exc_info.value)
+
         outputs_inference = torch.cat(outputs_inference, dim=1)
 
         ref_block = load_pretrained_block(MODEL_NAME, block_index, torch_dtype=torch.float32)
diff --git a/tests/test_chained_calls.py b/tests/test_chained_calls.py
index 84c4232..8148286 100644
--- a/tests/test_chained_calls.py
+++ b/tests/test_chained_calls.py
@@ -63,7 +63,7 @@ def test_chained_inference_exact_match(atol_inference=1e-4):
     inputs = torch.randn(1, 8, config.hidden_size)
 
     outputs_inference = []
-    with remote_block.inference_session() as sess:
+    with remote_block.inference_session(max_length=inputs.shape[1]) as sess:
         for i in range(inputs.shape[1]):
             outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))
     outputs_inference = torch.cat(outputs_inference, dim=1)
diff --git a/tests/test_full_model.py b/tests/test_full_model.py
index 47f08be..b0ce824 100644
--- a/tests/test_full_model.py
+++ b/tests/test_full_model.py
@@ -31,7 +31,7 @@ def test_full_model_exact_match(atol_forward=1e-3, atol_inference=1e-3):
         embs = model.transformer.word_embeddings(test_inputs)
         embs = model.transformer.word_embeddings_layernorm(embs)
         recurrent_outputs = []
-        with model.transformer.h.inference_session() as sess:
+        with model.transformer.h.inference_session(max_length=embs.shape[1]) as sess:
            for t in range(embs.shape[1]):
                 recurrent_outputs.append(sess.step(embs[:, t : t + 1, :]))
         recurrent_outputs = torch.cat(recurrent_outputs, dim=1)
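
Reviewer note, not part of the patch: a minimal sketch of the client-side API after this change, following the README and tests above. It assumes `layer3` is a RemoteTransformerBlock obtained as in the README example, with hidden size 4096 as there; `max_length` now tells the server how many tokens of attention cache to pre-allocate, and stepping past it fails.

import torch
from hivemind import P2PHandlerError

max_length = 10  # the server pre-allocates KV cache for exactly this many tokens
with layer3.inference_session(max_length=max_length) as sess:  # layer3: RemoteTransformerBlock (see README above)
    for _ in range(max_length):
        sess.step(torch.ones(1, 1, 4096))  # one token per step, within the pre-allocated budget
    try:
        sess.step(torch.ones(1, 1, 4096))  # one token too many for the allocated cache
    except P2PHandlerError as e:
        assert "Maximum length exceeded" in repr(e)  # raised by rpc_inference on the server side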
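
Also not part of the patch: a quick illustration of how the new --attn_cache_size strings are interpreted. The flag is parsed with humanfriendly.parse_size (humanfriendly is added to requirements.txt above), which is why the help text warns that 1KB != 1KiB.

from humanfriendly import parse_size

assert parse_size("1KB") == 1000                    # decimal kilobyte
assert parse_size("1KiB") == 1024                   # binary kibibyte, hence the "1KB != 1KiB" warning
assert parse_size("500MB") == 500 * 10**6
assert parse_size("0.2GiB") == int(0.2 * 1024**3)   # the value passed in run-tests.yaml
assert parse_size("1073741824") == 1073741824       # plain byte counts are accepted too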
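
Finally, a back-of-the-envelope sketch (an illustration, not code from this patch) of the per-block cache that _allocate_caches reserves, following the TensorDescriptor shape [key_or_value, batch_size, max_length, num_heads, head_dim] above; it may help pick --attn_cache_size and --inference_max_length together. The num_heads, head_dim, and dtype values below are placeholders, not taken from any particular model config.

def attn_cache_bytes(batch_size: int, max_length: int, num_heads: int, head_dim: int, dtype_size: int) -> int:
    """Bytes reserved per transformer block: keys and values for every position of every head."""
    return 2 * batch_size * max_length * num_heads * head_dim * dtype_size

# Example: one sequence of 2048 tokens, 32 heads of dim 128, fp16 (2 bytes per value):
print(attn_cache_bytes(1, 2048, 32, 128, 2))  # 33554432 bytes, i.e. 32 MiB per served block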