From d271b75dd42ebd2e18ee76434a70d5607e83039f Mon Sep 17 00:00:00 2001 From: justheuristic Date: Mon, 29 Aug 2022 21:04:37 +0300 Subject: [PATCH 01/10] Let users specify sequence length instead of assuming 2048 (#52) - Maximum length is now provided in `.inference_session(max_length=100)` - previously, we would always assume max length = 2048 - added a generic way to forward **kwargs to inference session - for compatibility with #47 - Note to @borzunov : it does *not* pass them arbitrarily, but instead checks for kwarg names at the bottom level - run_server can be started with a custom max_length for inference - renamed --cache_size_bytes to --attention_cache_bytes (to avoid collision with --cache_dir) - --attn_cache_bytes can now support humane file sizes (e.g. 300MB instead of 314572800) - made some server-side errors more human-readable to user (e.g. when max length is exceeded) Co-authored-by: Aleksandr Borzunov Co-authored-by: Alexander Borzunov --- .github/workflows/run-tests.yaml | 2 +- README.md | 2 +- cli/run_server.py | 19 +++++++++++++++---- requirements.txt | 1 + src/client/inference_session.py | 22 +++++++++++++++++----- src/client/remote_block.py | 8 ++------ src/client/remote_generation.py | 10 ++++++++-- src/client/remote_sequential.py | 4 ++-- src/server/backend.py | 2 -- src/server/handler.py | 30 ++++++++++++++++++++++++------ src/server/server.py | 10 +++++++--- tests/test_block_exact_match.py | 9 ++++++++- tests/test_chained_calls.py | 2 +- tests/test_full_model.py | 2 +- 14 files changed, 88 insertions(+), 35 deletions(-) diff --git a/.github/workflows/run-tests.yaml b/.github/workflows/run-tests.yaml index 60cb02f..3cb25ad 100644 --- a/.github/workflows/run-tests.yaml +++ b/.github/workflows/run-tests.yaml @@ -81,7 +81,7 @@ jobs: python -m cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 0:12 \ --torch_dtype float32 --identity tests/test.id --host_maddrs /ip4/127.0.0.1/tcp/31337 \ - --throughput 1 &> server1.log & + --throughput 1 --attn_cache_size 0.2GiB &> server1.log & SERVER1_PID=$! sleep 5 # wait for the first server to initialize DHT diff --git a/README.md b/README.md index aca399f..b59606d 100644 --- a/README.md +++ b/README.md @@ -53,7 +53,7 @@ loss = (outputs * torch.randn_like(outputs)).norm() loss.backward() # test inference, one block -with layer3.inference_session() as sess: +with layer3.inference_session(max_length=10) as sess: for i in range(10): res = sess.step(torch.ones(1, 1, 4096)) ``` diff --git a/cli/run_server.py b/cli/run_server.py index 1cdae73..03055f7 100644 --- a/cli/run_server.py +++ b/cli/run_server.py @@ -2,6 +2,7 @@ import configargparse from hivemind.proto.runtime_pb2 import CompressionType from hivemind.utils.limits import increase_file_limit from hivemind.utils.logging import get_logger, use_hivemind_log_handler +from humanfriendly import parse_size from src.server.server import Server @@ -32,16 +33,19 @@ def main(): parser.add_argument('--min_batch_size', type=int, default=1, help='Minimum required batch size for all expert operations') parser.add_argument('--max_batch_size', type=int, default=16384, - help='The total number of examples in the same batch will not exceed this value') + help='The total number of tokens in the same batch will not exceed this value') + parser.add_argument('--inference_max_length', type=int, default=16384, + help='Maximum total sequence length permitted per inference, defaults to 16384 tokens') parser.add_argument('--cache_dir', type=str, default=None, help='Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.') - parser.add_argument('--cache_size_bytes', type=int, default=None, - help='The size of memory cache for storing past attention keys/values between inference steps') parser.add_argument('--device', type=str, default=None, required=False, help='all experts will use this device in torch notation; default: cuda if available else cpu') parser.add_argument("--torch_dtype", type=str, default="auto", help="Use this dtype to store block weights and do computations. " "By default, respect the dtypes in the pre-trained state dict.") + parser.add_argument('--attn_cache_size', type=str, default=None, + help='The size of GPU memory allocated for storing past attention keys/values between inference' + ' steps; examples: 500MB or 1.2GB or 1073741824 (bytes); be warned: 1KB != 1KiB') parser.add_argument('--revision', type=str, default='main', help="The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models" "and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.") @@ -81,10 +85,17 @@ def main(): compression_type = args.pop("compression") compression = getattr(CompressionType, compression_type) + attn_cache_size = args.pop("attn_cache_size") + if attn_cache_size is not None: + attn_cache_size = parse_size(attn_cache_size) + assert isinstance( + attn_cache_size, (int, type(None)) + ), "unrecognized value for attention_cache_bytes, examples: 1.5GB or 1500MB or 1572864000 (bytes)" + use_auth_token = args.pop("use_auth_token") args["use_auth_token"] = True if use_auth_token in ("True", "true", "") else use_auth_token - server = Server.create(**args, start=True, compression=compression) + server = Server.create(**args, start=True, compression=compression, attn_cache_size=attn_cache_size) try: server.join() diff --git a/requirements.txt b/requirements.txt index feccf05..afc0290 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ torch==1.12.0 accelerate==0.10.0 huggingface-hub==0.7.0 +humanfriendly https://github.com/learning-at-home/hivemind/archive/20b3b3d5f225ed525515a5383a008a8f9fad8173.zip https://github.com/huggingface/transformers/archive/6589e510fa4e6c442059de2fab84752535de9b23.zip \ No newline at end of file diff --git a/src/client/inference_session.py b/src/client/inference_session.py index 94f6ffa..24852df 100644 --- a/src/client/inference_session.py +++ b/src/client/inference_session.py @@ -7,6 +7,7 @@ from typing import AsyncIterator, List, Optional import torch from hivemind import ( P2P, + MSGPackSerializer, anext, deserialize_torch_tensor, get_logger, @@ -33,23 +34,32 @@ class RemoteTransformerBlockInferenceSession: :note: this inference session is *not* fault-tolerant out of the box """ - def __init__(self, uid: ModuleUID, rpc_info: RPCInfo, inputs_queue: asyncio.Queue, outputs_aiter: AsyncIterator): + def __init__( + self, + uid: ModuleUID, + rpc_info: RPCInfo, + inputs_queue: asyncio.Queue, + outputs_aiter: AsyncIterator, + *, + max_length: int, + ): self.uid, self.rpc_info = uid, rpc_info # warning: this code manages async objects that are only usable inside RemoteExpertWorker's background thread; # using them in any other EventLoop may cause side-effects including, headaches, diarrhea, and loss of sleep self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter + self._serialized_metadata = MSGPackSerializer.dumps(dict(max_length=max_length)) self.stepped = False self.closed = False @classmethod async def _create( - cls, stub: StubBase, uid: ModuleUID, rpc_info: RPCInfo, timeout: Optional[float] = None + cls, stub: StubBase, uid: ModuleUID, rpc_info: RPCInfo, timeout: Optional[float] = None, **metadata ) -> RemoteTransformerBlockInferenceSession: """Create a new session for a given remote module. This code is meant to be run inside RemoteExpertWorker""" inputs_queue = asyncio.Queue() outputs_stream = await stub.rpc_inference(cls._read_inputs_from_queue(inputs_queue, timeout), timeout=timeout) - return cls(uid, rpc_info, inputs_queue, outputs_stream) + return cls(uid, rpc_info, inputs_queue, outputs_stream, **metadata) @staticmethod async def _read_inputs_from_queue(queue: asyncio.Queue, timeout: Optional[float]) -> AsyncIterator: @@ -73,6 +83,7 @@ class RemoteTransformerBlockInferenceSession: serialize_torch_tensor(tensor.to(proto.dtype), proto.compression) for tensor, proto in zip(inputs, nested_flatten(self.rpc_info["forward_schema"])) ], + metadata=self._serialized_metadata if not self.stepped else None, ) ) ) @@ -121,13 +132,14 @@ class RemoteSequentialInferenceSession: An interface to a multi-step *inference* session for a sequence of remote transformer blocks """ - def __init__(self, sequence_manager: RemoteSequenceManager, p2p: P2P, timeout: Optional[float] = None): + def __init__(self, sequence_manager: RemoteSequenceManager, p2p: P2P, timeout: Optional[float] = None, **metadata): self.sequence_manager = sequence_manager self.p2p = p2p self.closed = False self.chosen_spans: List[RemoteSpanInfo] = [] self.stack = contextlib.ExitStack() self.inference_sessions: List[RemoteTransformerBlockInferenceSession] = [] + self.metadata = metadata self.timeout = timeout def __enter__(self): @@ -141,7 +153,7 @@ class RemoteSequentialInferenceSession: span_uids: str = CHAIN_DELIMITER.join(self.sequence_manager.block_uids[chosen_span.start : chosen_span.end]) inference_session = RemoteExpertWorker.run_coroutine( RemoteTransformerBlockInferenceSession._create( - stub, span_uids, rpc_info=self.sequence_manager.rpc_info, timeout=self.timeout + stub, span_uids, rpc_info=self.sequence_manager.rpc_info, timeout=self.timeout, **self.metadata ) ) self.inference_sessions.append(inference_session) diff --git a/src/client/remote_block.py b/src/client/remote_block.py index 68cd004..7d0f920 100644 --- a/src/client/remote_block.py +++ b/src/client/remote_block.py @@ -33,12 +33,8 @@ class RemoteTransformerBlock(RemoteExpert): assert v is None or v is False, f"Extra keyword arguments are not yet supported (got {k} = {v})" return super().forward(inputs) - def inference_session(self) -> RemoteTransformerBlockInferenceSession: + def inference_session(self, **kwargs) -> RemoteTransformerBlockInferenceSession: """Initialize a new inference session with the specified remote server""" return RemoteExpertWorker.run_coroutine( - RemoteTransformerBlockInferenceSession._create(self.stub, self.uid, self.info) + RemoteTransformerBlockInferenceSession._create(self.stub, self.uid, self.info, **kwargs) ) - - def begin_inference_session(self): - logger.warning("beging_inference_session was renamed to just inference_session") - return self.inference_session() diff --git a/src/client/remote_generation.py b/src/client/remote_generation.py index d20462a..1b1b5cd 100644 --- a/src/client/remote_generation.py +++ b/src/client/remote_generation.py @@ -60,14 +60,20 @@ class RemoteGenerationMixin: assert ( model_kwargs.get("stopping_criteria", None) is None ), "For RemoteGenerationMixin models use BloomConstraints instead of stopping_criteria" + if inputs is not None: + assert isinstance(inputs, torch.Tensor) and inputs.ndim == 2, "inputs must be a 2d tensor [batch, length]" + prefix_length = 0 if inputs is None else inputs.size(1) bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id + assert (max_length is None) != (max_new_tokens is None), "please set max_length or max_new_tokens (not both)" if max_length is not None and max_new_tokens is None: - max_new_tokens = max_length - inputs.size(1) + max_new_tokens = max_length - prefix_length assert max_new_tokens > 0, f"Provided max_length is less than prefix size: {max_length} < {inputs.size(1)}" + elif max_length is None and max_new_tokens is not None: + max_length = prefix_length + max_new_tokens if inputs is None: assert bos_token_id is not None, "You have to provide a bos_token_id if you do not provide inputs" @@ -87,7 +93,7 @@ class RemoteGenerationMixin: provided_constraints=provided_constraints, ) - with self.transformer.h.inference_session() as sess: + with self.transformer.h.inference_session(max_length=max_length) as sess: outputs = [] if torch.any(inputs == pad_token_id): # TODO: move to prepare_inputs outputs += [inputs[:, : inputs.size(1) - (inputs == pad_token_id).sum(-1).max()]] diff --git a/src/client/remote_sequential.py b/src/client/remote_sequential.py index 86acfe1..86cca85 100644 --- a/src/client/remote_sequential.py +++ b/src/client/remote_sequential.py @@ -79,9 +79,9 @@ class RemoteSequential(nn.Module): def __len__(self): return len(self.sequence_manager) - def inference_session(self) -> RemoteSequentialInferenceSession: + def inference_session(self, **kwargs) -> RemoteSequentialInferenceSession: self.sequence_manager.update_() - return RemoteSequentialInferenceSession(self.sequence_manager, self.p2p) + return RemoteSequentialInferenceSession(self.sequence_manager, self.p2p, **kwargs) def extra_repr(self) -> str: return f"modules={self.sequence_manager.block_uids[0]}..{self.sequence_manager.block_uids[-1]}" diff --git a/src/server/backend.py b/src/server/backend.py index 6a883f6..9929770 100644 --- a/src/server/backend.py +++ b/src/server/backend.py @@ -14,8 +14,6 @@ from src.server.cache import MemoryCache use_hivemind_log_handler("in_root_logger") logger = get_logger(__file__) -MAX_LENGTH = 2048 - class InferenceTaskPool(TaskPool): def __init__(self, *args, **kwargs): diff --git a/src/server/handler.py b/src/server/handler.py index c50a8aa..27ed562 100644 --- a/src/server/handler.py +++ b/src/server/handler.py @@ -19,7 +19,7 @@ from hivemind.utils.asyncio import anext from hivemind.utils.streaming import split_for_streaming from src.data_structures import CHAIN_DELIMITER, ModuleUID -from src.server.backend import MAX_LENGTH, TransformerBackend +from src.server.backend import TransformerBackend from src.utils.misc import DUMMY, is_dummy @@ -28,10 +28,11 @@ class TransformerConnectionHandler(ConnectionHandler): module_backends: Dict[ModuleUID, TransformerBackend] - def __init__(self, dht: DHT, module_backends: Dict[str, TransformerBackend]): + def __init__(self, dht: DHT, module_backends: Dict[str, TransformerBackend], inference_max_length: int): super().__init__(dht, module_backends) for module_backend in self.module_backends.values(): assert isinstance(module_backend, TransformerBackend) + self.inference_max_length = inference_max_length async def rpc_inference( self, @@ -43,7 +44,15 @@ class TransformerConnectionHandler(ConnectionHandler): print("OPENED RPC_INFERENCE") request = await anext(requests) requested_uids = self._check_uids(request.uid) + metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {} requested_backends = tuple(self.module_backends[uid] for uid in requested_uids) + max_length = metadata.get("max_length") + + if not requested_uids: + raise ValueError("User must specify at least one block for inference, but got none") + assert isinstance(max_length, int), f"rpc_inference metadata must contain int max_length, got {max_length}" + if not 0 <= max_length <= self.inference_max_length: + raise ValueError(f"Cannot allocate KV cache for {max_length} tokens, max = {self.inference_max_length}") batch_size = request.tensors[0].size[0] if request.tensors else 1 @@ -52,10 +61,17 @@ class TransformerConnectionHandler(ConnectionHandler): ) # [cache_handle, prefix_length] prefix_length = 0 - async with self._allocate_caches(requested_backends, batch_size) as cache_handles: + async with self._allocate_caches(requested_backends, batch_size, max_length) as cache_handles: assert len(cache_handles) == len(requested_backends) while request.tensors: # iterate while user is willing to supply tensors hidden_states = [deserialize_torch_tensor(tensor) for tensor in request.tensors] + length_increment = hidden_states[0].shape[1] # how many tokens are added this step (in each seq) + + if prefix_length + length_increment > max_length: + raise ValueError( + f"Maximum length exceeded: prefix {prefix_length} + current {length_increment}" + f" exceeds pre-allocated maximum {max_length}" + ) # Cast inputs to backend dtype hidden_states = [tensor.to(requested_backends[0].dtype) for tensor in hidden_states] @@ -113,7 +129,7 @@ class TransformerConnectionHandler(ConnectionHandler): requested_backends = tuple(self.module_backends[uid] for uid in requested_uids) hidden_states = await _rpc_forward(*flat_inputs, requested_backends=requested_backends) - assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3 + assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3, "hidden_states must be a 3d tensor" # Serialize the overall output serialized_output = [ @@ -193,7 +209,9 @@ class TransformerConnectionHandler(ConnectionHandler): return tuple(uids) @contextlib.asynccontextmanager - async def _allocate_caches(self, backends: Sequence[TransformerBackend], batch_size: int) -> Sequence[int]: + async def _allocate_caches( + self, backends: Sequence[TransformerBackend], batch_size: int, max_length: int + ) -> Sequence[int]: """Allocate memory caches for each transformer block, return cache handles""" async with contextlib.AsyncExitStack() as stack: handles = [] @@ -202,7 +220,7 @@ class TransformerConnectionHandler(ConnectionHandler): head_dim = backend.module.self_attention.head_dim cache_descriptor = TensorDescriptor( - size=(2, batch_size, MAX_LENGTH, num_heads, head_dim), dtype=backend.dtype + size=(2, batch_size, max_length, num_heads, head_dim), dtype=backend.dtype ) # [key_or_value, batch_size, max_length, num_heads, head_dim] diff --git a/src/server/server.py b/src/server/server.py index 057daa5..5d92bd9 100644 --- a/src/server/server.py +++ b/src/server/server.py @@ -36,6 +36,7 @@ class Server(threading.Thread): dht: DHT, module_backends: Dict[str, TransformerBackend], *, + inference_max_length: int, num_connection_handlers: int = 8, throughput: float, update_period: float = 30, @@ -47,7 +48,8 @@ class Server(threading.Thread): self.dht, self.module_backends = dht, module_backends self.throughput, self.update_period, self.expiration = throughput, update_period, expiration self.conn_handlers = [ - TransformerConnectionHandler(dht, self.module_backends) for _ in range(num_connection_handlers) + TransformerConnectionHandler(dht, self.module_backends, inference_max_length) + for _ in range(num_connection_handlers) ] self.runtime = Runtime(self.module_backends, **kwargs) self.dht_handler_thread = ModuleAnnouncerThread( @@ -104,10 +106,11 @@ class Server(threading.Thread): num_handlers: int = 8, min_batch_size: int = 1, max_batch_size: int = 4096, + inference_max_length: int = 4096, torch_dtype: str = "auto", revision: str = "main", cache_dir: Optional[str] = None, - cache_size_bytes: Optional[int] = None, + attn_cache_size: Optional[int] = None, device: Optional[Union[str, torch.device]] = None, initial_peers: Sequence[str] = (), compression=CompressionType.NONE, @@ -141,7 +144,7 @@ class Server(threading.Thread): logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}") device = device or ("cuda" if torch.cuda.is_available() else "cpu") - memory_cache = MemoryCache(device, cache_size_bytes) + memory_cache = MemoryCache(device, attn_cache_size) assert isinstance(throughput, float) or throughput in ["auto", "eval"] if throughput in ["auto", "eval"]: @@ -228,6 +231,7 @@ class Server(threading.Thread): blocks, throughput=throughput, num_connection_handlers=num_handlers, + inference_max_length=inference_max_length, device=device, stats_report_interval=stats_report_interval, update_period=update_period, diff --git a/tests/test_block_exact_match.py b/tests/test_block_exact_match.py index caac346..4761aea 100644 --- a/tests/test_block_exact_match.py +++ b/tests/test_block_exact_match.py @@ -4,6 +4,7 @@ import hivemind import pytest import torch import transformers +from hivemind import P2PHandlerError from test_utils import * from src.bloom.from_pretrained import load_pretrained_block @@ -27,9 +28,15 @@ def test_remote_block_exact_match(atol_forward=1e-5, atol_inference=1e-3): (outputs_forward,) = remote_block(inputs) outputs_inference = [] - with remote_block.inference_session() as sess: + with remote_block.inference_session(max_length=inputs.shape[1]) as sess: for i in range(inputs.shape[1]): outputs_inference.append(sess.step(inputs[:, i : i + 1, :])) + + # test that max length is respected + with pytest.raises(P2PHandlerError) as exc_info: + sess.step(inputs[:, -1:, :]) + assert "Maximum length exceeded" in repr(exc_info.value) + outputs_inference = torch.cat(outputs_inference, dim=1) ref_block = load_pretrained_block(MODEL_NAME, block_index, torch_dtype=torch.float32) diff --git a/tests/test_chained_calls.py b/tests/test_chained_calls.py index 84c4232..8148286 100644 --- a/tests/test_chained_calls.py +++ b/tests/test_chained_calls.py @@ -63,7 +63,7 @@ def test_chained_inference_exact_match(atol_inference=1e-4): inputs = torch.randn(1, 8, config.hidden_size) outputs_inference = [] - with remote_block.inference_session() as sess: + with remote_block.inference_session(max_length=inputs.shape[1]) as sess: for i in range(inputs.shape[1]): outputs_inference.append(sess.step(inputs[:, i : i + 1, :])) outputs_inference = torch.cat(outputs_inference, dim=1) diff --git a/tests/test_full_model.py b/tests/test_full_model.py index 47f08be..b0ce824 100644 --- a/tests/test_full_model.py +++ b/tests/test_full_model.py @@ -31,7 +31,7 @@ def test_full_model_exact_match(atol_forward=1e-3, atol_inference=1e-3): embs = model.transformer.word_embeddings(test_inputs) embs = model.transformer.word_embeddings_layernorm(embs) recurrent_outputs = [] - with model.transformer.h.inference_session() as sess: + with model.transformer.h.inference_session(max_length=embs.shape[1]) as sess: for t in range(embs.shape[1]): recurrent_outputs.append(sess.step(embs[:, t : t + 1, :])) recurrent_outputs = torch.cat(recurrent_outputs, dim=1) From 77220c718c6b2c4d888349f6961f94c8cf159081 Mon Sep 17 00:00:00 2001 From: Artem Chumachenko Date: Wed, 31 Aug 2022 13:21:25 +0400 Subject: [PATCH 02/10] Add shallow prefix-tuned inference (#55) * Add prefix-tuned inference * Add prefix-tuned inference * Add preseq_length in prefix size --- src/client/remote_generation.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/client/remote_generation.py b/src/client/remote_generation.py index 1b1b5cd..e4875cc 100644 --- a/src/client/remote_generation.py +++ b/src/client/remote_generation.py @@ -63,6 +63,7 @@ class RemoteGenerationMixin: if inputs is not None: assert isinstance(inputs, torch.Tensor) and inputs.ndim == 2, "inputs must be a 2d tensor [batch, length]" prefix_length = 0 if inputs is None else inputs.size(1) + prefix_length += self.config.pre_seq_len bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id @@ -104,6 +105,9 @@ class RemoteGenerationMixin: hypo_ids = torch.arange(outputs[0].size(0)) while True: embs = self.transformer.word_embeddings(outputs[-1]) + if self.config.pre_seq_len > 0 and len(outputs) == 1: + prompts, _ = self.transformer.get_prompt(embs.size(0)) + embs = torch.cat([prompts, embs], dim=1) embs = self.transformer.word_embeddings_layernorm(embs) hidden_state = sess.step(embs)[:, -1] hidden_state = self.transformer.ln_f(hidden_state) From 0be21775af3ea3e540b069f05a2f137d9cc3e862 Mon Sep 17 00:00:00 2001 From: Pavel Samygin <44449246+GreenFatGuy@users.noreply.github.com> Date: Thu, 1 Sep 2022 04:26:31 +0300 Subject: [PATCH 03/10] remove transformer block, implement as sequential of size 1 (#54) * remove transformer block, implement as sequence size 1 * reimplement get_remote_module * fix readme Co-authored-by: Alexander Borzunov Co-authored-by: Aleksandr Borzunov --- README.md | 28 +++++------ src/client/__init__.py | 3 +- src/client/remote_block.py | 40 ---------------- src/client/remote_sequential.py | 30 +++++++++--- src/client/sequence_manager.py | 21 +++++---- src/client/sequential_autograd.py | 2 +- src/dht_utils.py | 77 +++++++++++++++++++------------ tests/test_block_exact_match.py | 12 ++--- tests/test_chained_calls.py | 31 +++++-------- 9 files changed, 115 insertions(+), 129 deletions(-) delete mode 100644 src/client/remote_block.py diff --git a/README.md b/README.md index b59606d..0e5547b 100644 --- a/README.md +++ b/README.md @@ -37,18 +37,18 @@ Then open a python notebook or console and run: ```python import torch import hivemind -from src import get_remote_module +from src import DistributedBloomConfig, get_remote_module dht = hivemind.DHT( initial_peers=[TODO_COPY_FULL_ADDRESS_FROM_ANY_OF_THE_SERVERS], # e.g. /ip4/127.0.0.1/... client_mode=True, start=True, ) - -layer3, layer4 = get_remote_module(dht, ['bigscience/test-bloomd-6b3.3', 'bigscience/test-bloomd-6b3.4']) +config = DistributedBloomConfig.from_pretrained("bigscience/test-bloom-6b3") +layer3, layer4 = get_remote_module(dht, ['bigscience/test-bloomd-6b3.3', 'bigscience/test-bloomd-6b3.4'], config) assert layer3 is not None and layer4 is not None, "one or both layers were not found in DHT" # test forward/backward, two blocks -outputs, = layer4(*layer3(torch.randn(1, 64, 4096))) +outputs = layer4(layer3(torch.randn(1, 64, 4096))) loss = (outputs * torch.randn_like(outputs)).norm() loss.backward() @@ -74,18 +74,18 @@ python -m cli.convert_model --model bigscience/bloom-6b3 \ To test distributed inference, run one or more servers, then open a new shell and run pytest with environment variables: ```bash -# shell A: serve blocks 3 and 4 +# shell A: serve model python -m cli.run_server --converted_model_name_or_path bigscience/test-bloomd-6b3 \ - --block_indices 3:5 --torch_dtype float32 --identity_path ./server1.id --host_maddrs /ip4/127.0.0.1/tcp/31337 + --torch_dtype float32 --identity_path ./server1.id --host_maddrs /ip4/127.0.0.1/tcp/31337 -# shell B: connect to the swarm and test individual blocks for exact match -export PYTHONPATH=. INITIAL_PEERS="/ip4/TODO_COPY_INITIAL_PEERS_FROM_SERVER_OUTPUT" -BLOCK_UID=bigscience/test-bloomd-6b3.3 pytest tests/test_block_exact_match.py -BLOCK_UID=bigscience/test-bloomd-6b3.4 pytest tests/test_block_exact_match.py +# shell B: +export PYTHONPATH=. +export INITIAL_PEERS="/ip4/TODO_COPY_INITIAL_PEERS_FROM_SERVER_OUTPUT" +export MODEL_NAME="bigscience/test-bloomd-6b3" -# the test below will fail because there is no server that serves layer 7 -# BLOCK_UID=bigscience/test-bloomd-6b3.7 pytest tests/test_block_exact_match.py +# test individual random blocks for exact match +pytest tests/test_block_exact_match.py -# test the full model (requires that servers collectively serve all model layers) -REF_NAME=bigscience/bloom-6b3 pytest tests/test_full_model.py +# test the full model +pytest tests/test_full_model.py ``` diff --git a/src/client/__init__.py b/src/client/__init__.py index 0335921..165de67 100644 --- a/src/client/__init__.py +++ b/src/client/__init__.py @@ -1,5 +1,4 @@ from src.client.inference_session import RemoteSequentialInferenceSession, RemoteTransformerBlockInferenceSession -from src.client.remote_block import RemoteTransformerBlock from src.client.remote_model import DistributedBloomConfig, DistributedBloomForCausalLM, DistributedBloomModel -from src.client.remote_sequential import RemoteSequential +from src.client.remote_sequential import RemoteSequential, RemoteTransformerBlock from src.client.sequence_manager import RemoteSequenceManager diff --git a/src/client/remote_block.py b/src/client/remote_block.py deleted file mode 100644 index 7d0f920..0000000 --- a/src/client/remote_block.py +++ /dev/null @@ -1,40 +0,0 @@ -# Note: this code is being actively modified by justheuristic. If you want to change anything about it, please warn me. -from __future__ import annotations - -import random - -import torch -from hivemind.moe.client.expert import RemoteExpert, RemoteExpertWorker -from hivemind.moe.expert_uid import ExpertInfo -from hivemind.p2p import P2P, StubBase -from hivemind.utils import get_logger, use_hivemind_log_handler - -from src.client.inference_session import RemoteTransformerBlockInferenceSession -from src.data_structures import RemoteModuleInfo -from src.server.handler import TransformerConnectionHandler - -use_hivemind_log_handler("in_root_logger") -logger = get_logger(__file__) - - -class RemoteTransformerBlock(RemoteExpert): - """A class that interacts with a remote module on a specific server for forward/backward or inference""" - - def __init__(self, peers_info: RemoteModuleInfo, p2p: P2P): - peer_info = ExpertInfo(peers_info.uid, random.choice(list(peers_info.servers.keys()))) # TODO replace this - super().__init__(peer_info, p2p) - - @property - def stub(self) -> StubBase: - return TransformerConnectionHandler.get_stub(self.p2p, self.peer_id) - - def forward(self, inputs: torch.Tensor, **kwargs): - for k, v in kwargs.items(): - assert v is None or v is False, f"Extra keyword arguments are not yet supported (got {k} = {v})" - return super().forward(inputs) - - def inference_session(self, **kwargs) -> RemoteTransformerBlockInferenceSession: - """Initialize a new inference session with the specified remote server""" - return RemoteExpertWorker.run_coroutine( - RemoteTransformerBlockInferenceSession._create(self.stub, self.uid, self.info, **kwargs) - ) diff --git a/src/client/remote_sequential.py b/src/client/remote_sequential.py index 86cca85..d9e63b2 100644 --- a/src/client/remote_sequential.py +++ b/src/client/remote_sequential.py @@ -1,6 +1,5 @@ from __future__ import annotations -import logging from typing import Optional, Union import torch @@ -10,11 +9,9 @@ from torch import nn import src from src.client.inference_session import RemoteSequentialInferenceSession -from src.client.remote_block import RemoteTransformerBlock from src.client.sequence_manager import RemoteSequenceManager from src.client.sequential_autograd import _RemoteSequentialAutogradFunction from src.data_structures import UID_DELIMITER -from src.dht_utils import _create_remote_modules_from_infos from src.utils.misc import DUMMY use_hivemind_log_handler("in_root_logger") @@ -57,12 +54,16 @@ class RemoteSequential(nn.Module): outputs = _RemoteSequentialAutogradFunction.apply(inputs, prompts, self.sequence_manager) return outputs - def __getitem__(self, ix: Union[int, slice]) -> Union[RemoteTransformerBlock, RemoteSequential]: + def __getitem__(self, ix: Union[int, slice]) -> RemoteSequential: assert isinstance(ix, (int, slice)) if isinstance(ix, int): - assert 0 <= ix < len(self) - (module,) = _create_remote_modules_from_infos([self.sequence_manager.block_infos[ix]], self.p2p) - return module + return RemoteTransformerBlock( + self.config, + self.dht, + dht_prefix=self.dht_prefix, + p2p=self.p2p, + sequence_manager=self.sequence_manager[ix], + ) else: return RemoteSequential( self.config, @@ -85,3 +86,18 @@ class RemoteSequential(nn.Module): def extra_repr(self) -> str: return f"modules={self.sequence_manager.block_uids[0]}..{self.sequence_manager.block_uids[-1]}" + + +class RemoteTransformerBlock(RemoteSequential): + """Single transformer block hosted by swarm + + This class is deprecated and kept for backward compatibility. + It will be removed soon in favor of using ``RemoteSequential`` directly. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + assert len(self) == 1, "Remote Block is a sequence size 1" + + def extra_repr(self): + return f"{self.sequence_manager.block_uids[0]}" diff --git a/src/client/sequence_manager.py b/src/client/sequence_manager.py index 777f070..af552dd 100644 --- a/src/client/sequence_manager.py +++ b/src/client/sequence_manager.py @@ -82,6 +82,7 @@ class RemoteSequenceManager: for block_index, (uid, info) in enumerate(zip(self.block_uids, new_block_infos)): if info is None: logger.warning(f"Found no block info for block {uid}") + continue if not isinstance(info, RemoteModuleInfo): logger.warning(f"Unexpected dht entry type for {uid}: {info}") if not info.servers: @@ -95,22 +96,24 @@ class RemoteSequenceManager: closed_spans = [] active_spans = {} for block_index, info in enumerate(block_infos): - for peer_id, server in info.servers.items(): - if server.state != ServerState.ONLINE: - continue - if peer_id not in active_spans: - active_spans[peer_id] = RemoteSpanInfo(start=block_index, end=block_index + 1, peer_id=peer_id) - else: # peer_id in active_spans - active_spans[peer_id].end = block_index + 1 + if info is not None: + for peer_id, server in info.servers.items(): + if server.state != ServerState.ONLINE: + continue + if peer_id not in active_spans: + active_spans[peer_id] = RemoteSpanInfo(start=block_index, end=block_index + 1, peer_id=peer_id) + else: # peer_id in active_spans + active_spans[peer_id].end = block_index + 1 for peer_id in list(active_spans.keys()): if ( - peer_id not in info.servers + info is None + or peer_id not in info.servers or info.servers[peer_id].state != ServerState.ONLINE or block_index == len(block_infos) - 1 ): closed_spans.append(active_spans.pop(peer_id)) - assert not active_spans + assert not active_spans, f"spans: {active_spans}" closed_spans.sort(key=lambda span: span.end - span.start, reverse=True) diff --git a/src/client/sequential_autograd.py b/src/client/sequential_autograd.py index 081194c..1498236 100644 --- a/src/client/sequential_autograd.py +++ b/src/client/sequential_autograd.py @@ -110,7 +110,7 @@ async def sequential_forward( If some subsequence fails, reconstructs the remaining path and tries to finish the forward. """ - assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3 + assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3, f"{type(inputs)}: {inputs.ndim}" end_index = end_index if end_index is not None else len(sequence_manager.block_uids) assert start_index >= 0 and end_index <= len(sequence_manager.block_uids) diff --git a/src/dht_utils.py b/src/dht_utils.py index fe5df32..78ef083 100644 --- a/src/dht_utils.py +++ b/src/dht_utils.py @@ -9,7 +9,7 @@ from typing import Dict, List, Optional, Sequence, Union from hivemind.dht import DHT, DHTNode, DHTValue from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker -from hivemind.p2p import P2P, PeerID +from hivemind.p2p import PeerID from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger, use_hivemind_log_handler import src @@ -72,34 +72,63 @@ async def _declare_active_modules( ) +def get_remote_sequence( + dht: DHT, + start: int, + stop: int, + config: src.DistributedBloomConfig, + dht_prefix: Optional[str] = None, + return_future: bool = False, +) -> Union[src.RemoteSequential, MPFuture]: + return RemoteExpertWorker.run_coroutine( + _get_remote_sequence(dht, start, stop, config, dht_prefix), return_future=return_future + ) + + +async def _get_remote_sequence( + dht: DHT, + start: int, + stop: int, + config: src.DistributedBloomConfig, + dht_prefix: Optional[str] = None, +) -> src.RemoteSequential: + uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(start, stop)] + p2p = await dht.replicate_p2p() + manager = src.RemoteSequenceManager(dht, uids, p2p) + return src.RemoteSequential(config, dht, dht_prefix, p2p, manager) + + def get_remote_module( dht: DHT, uid_or_uids: Union[ModuleUID, List[ModuleUID]], - expiration_time: Optional[DHTExpiration] = None, + config: src.DistributedBloomConfig, + dht_prefix: Optional[str] = None, return_future: bool = False, -) -> Union[List[Optional[src.RemoteTransformerBlock]], MPFuture[List[Optional[src.RemoteTransformerBlock]]]]: +) -> Union[Union[src.RemoteTransformerBlock, List[src.RemoteTransformerBlock]], MPFuture]: """ :param uid_or_uids: find one or more modules with these ids from across the DHT - :param expiration_time: if specified, return modules that expire no sooner than this (based on get_dht_time) + :param config: model config, usualy taken by .from_pretrained(MODEL_NAME) :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background. - :returns: a list of [RemoteTransformerBlock if found else None] + :returns: a list of [RemoteTransformerBlock] """ - single_uid = isinstance(uid_or_uids, ModuleUID) - uids = [uid_or_uids] if single_uid else uid_or_uids - infos = dht.run_coroutine( - partial(_get_remote_module_infos, uids=uids, expiration_time=expiration_time), return_future + return RemoteExpertWorker.run_coroutine( + _get_remote_module(dht, uid_or_uids, config, dht_prefix), return_future=return_future ) - if return_future: - - async def _unpack(infos_future: MPFuture, dht: DHT): - p2p = await dht.replicate_p2p() - modules = _create_remote_modules_from_infos(await infos_future, p2p) - return modules[0] if single_uid else modules - return RemoteExpertWorker.run_coroutine(_unpack(infos, dht), return_future) - p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p()) - modules = _create_remote_modules_from_infos(infos, p2p) +async def _get_remote_module( + dht: DHT, + uid_or_uids: Union[ModuleUID, List[ModuleUID]], + config: src.DistributedBloomConfig, + dht_prefix: Optional[str] = None, +) -> Union[src.RemoteTransformerBlock, List[src.RemoteTransformerBlock]]: + single_uid = isinstance(uid_or_uids, ModuleUID) + uids = [uid_or_uids] if single_uid else uid_or_uids + p2p = await dht.replicate_p2p() + managers = (src.RemoteSequenceManager(dht, [uid], p2p) for uid in uids) + modules = [ + src.RemoteTransformerBlock(config, dht, dht_prefix=dht_prefix, p2p=p2p, sequence_manager=m) for m in managers + ] return modules[0] if single_uid else modules @@ -149,15 +178,3 @@ async def _get_remote_module_infos( if servers: modules[i] = RemoteModuleInfo(uid, servers) return modules - - -def _create_remote_modules_from_infos( - infos: Sequence[Optional[RemoteModuleInfo]], p2p: P2P -) -> List[Optional[src.RemoteTransformerBlock]]: - modules: List[Optional[src.RemoteTransformerBlock]] = [] - for info in infos: - if info is not None: - modules.append(src.RemoteTransformerBlock(info, p2p)) - else: - modules.append(None) - return modules diff --git a/tests/test_block_exact_match.py b/tests/test_block_exact_match.py index 4761aea..fad84ae 100644 --- a/tests/test_block_exact_match.py +++ b/tests/test_block_exact_match.py @@ -7,8 +7,10 @@ import transformers from hivemind import P2PHandlerError from test_utils import * +import src +from src import DistributedBloomConfig from src.bloom.from_pretrained import load_pretrained_block -from src.client.remote_block import RemoteTransformerBlock +from src.client.remote_sequential import RemoteTransformerBlock from src.data_structures import UID_DELIMITER from src.dht_utils import get_remote_module @@ -16,16 +18,14 @@ from src.dht_utils import get_remote_module @pytest.mark.forked def test_remote_block_exact_match(atol_forward=1e-5, atol_inference=1e-3): dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True) - config = transformers.AutoConfig.from_pretrained(MODEL_NAME) + config = DistributedBloomConfig.from_pretrained(MODEL_NAME) for block_index in random.sample(range(config.n_layer), 3): - block_uid = f"{MODEL_NAME}{UID_DELIMITER}{block_index}" - remote_block = get_remote_module(dht, block_uid) - assert remote_block is not None, f"Could not find {block_uid} in DHT" + remote_block = get_remote_module(dht, f"{MODEL_NAME}{UID_DELIMITER}{block_index}", config) assert isinstance(remote_block, RemoteTransformerBlock) inputs = torch.randn(1, 8, config.hidden_size) - (outputs_forward,) = remote_block(inputs) + outputs_forward = remote_block(inputs) outputs_inference = [] with remote_block.inference_session(max_length=inputs.shape[1]) as sess: diff --git a/tests/test_chained_calls.py b/tests/test_chained_calls.py index 8148286..7cf6d44 100644 --- a/tests/test_chained_calls.py +++ b/tests/test_chained_calls.py @@ -7,25 +7,20 @@ import hivemind import pytest import torch -import transformers -from hivemind.moe.expert_uid import UID_DELIMITER, ExpertInfo from test_utils import * +import src from src.bloom.from_pretrained import load_pretrained_block -from src.client.remote_block import RemoteTransformerBlock -from src.dht_utils import get_remote_module +from src.client.remote_sequential import RemoteSequential +from src.dht_utils import get_remote_sequence @pytest.mark.forked def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq_length=1): dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True) - config = transformers.AutoConfig.from_pretrained(MODEL_NAME) - remote_block = get_remote_module(dht, f"{MODEL_NAME}{UID_DELIMITER}0") - assert remote_block is not None, f"Could not find {MODEL_NAME}{UID_DELIMITER}0 in DHT" - assert isinstance(remote_block, RemoteTransformerBlock) - - _ = remote_block.info # lazy-init info now, because otherwise we will _break_ info init by chaning _info - remote_block._info = ExpertInfo(f"{MODEL_NAME}.3 {MODEL_NAME}.4 {MODEL_NAME}.5", remote_block._info.peer_id) + config = src.DistributedBloomConfig.from_pretrained(MODEL_NAME) + remote_blocks = get_remote_sequence(dht, 3, 6, config) + assert isinstance(remote_blocks, RemoteSequential) ref_blocks = [ load_pretrained_block(MODEL_NAME, 3, torch_dtype=torch.float32), @@ -33,7 +28,7 @@ def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq load_pretrained_block(MODEL_NAME, 5, torch_dtype=torch.float32), ] inputs = torch.randn(1, seq_length, config.hidden_size, requires_grad=True) - outputs_rpc = remote_block.forward(inputs)[0] + outputs_rpc = remote_blocks.forward(inputs) outputs_rpc.sum().backward() grads_rpc = inputs.grad @@ -52,18 +47,14 @@ def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq @pytest.mark.forked def test_chained_inference_exact_match(atol_inference=1e-4): dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True) - config = transformers.AutoConfig.from_pretrained(MODEL_NAME) - remote_block = get_remote_module(dht, f"{MODEL_NAME}{UID_DELIMITER}0") - assert remote_block is not None, f"Could not find {MODEL_NAME}{UID_DELIMITER}0 in DHT" - assert isinstance(remote_block, RemoteTransformerBlock) - - _ = remote_block.info # lazy-init info now, because otherwise we will _break_ info init by chaning _info - remote_block._info = ExpertInfo(f"{MODEL_NAME}.3 {MODEL_NAME}.4", remote_block._info.peer_id) + config = src.DistributedBloomConfig.from_pretrained(MODEL_NAME) + remote_blocks = get_remote_sequence(dht, 3, 5, config) + assert isinstance(remote_blocks, RemoteSequential) inputs = torch.randn(1, 8, config.hidden_size) outputs_inference = [] - with remote_block.inference_session(max_length=inputs.shape[1]) as sess: + with remote_blocks.inference_session(max_length=inputs.shape[1]) as sess: for i in range(inputs.shape[1]): outputs_inference.append(sess.step(inputs[:, i : i + 1, :])) outputs_inference = torch.cat(outputs_inference, dim=1) From 2eb58438524bcdb23d5f9fe785abc36cffb8b00d Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Thu, 1 Sep 2022 08:41:49 +0400 Subject: [PATCH 04/10] Update readme for the 1st public release (#57) --- README.md | 68 +++++++++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 61 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 0e5547b..3f70ccf 100644 --- a/README.md +++ b/README.md @@ -1,13 +1,59 @@ -# PETALS: Collaborative Inference of Large Models +

+
+ Decentralized platform for running 100B+ language models

+ + + + + + +

-Run BLOOM-176B, the largest open language model, by collaborating over the Internet. +## Key features -__[EARLY PROTOTYPE]__ - this project is a work in progress. Stuff breaks and gets fixed every day. Docs are nonexistent. -If you want us to wake you up when it's ready, click Watch -> Custom and tick "Releases". +- Run inference or fine-tune [BLOOM-176B](https://huggingface.co/bigscience/bloom) by joining compute resources with people all over the Internet. No need to have high-end GPUs. +- One inference step takes ≈ 1 sec — much faster than possible with offloading. Enough for chatbots and other interactive apps. +- Employ any fine-tuning and sampling methods by accessing model's hidden states and changing its control flow — something you can't do in proprietary APIs. -Roadmap: [__Issue #12__](https://github.com/learning-at-home/bloom-demo/issues/12) +

+ [Read paper] | [View website] +

-### Installation +## How it works? + +

+ +

+ +### 🚧 This project is in active development + +Be careful: some features may not work, interfaces may change, and we have no detailed docs yet (see [roadmap](https://github.com/bigscience-workshop/petals/issues/12)). + +A stable version of the code and a public swarm open to everyone will be released in November 2022. You can [subscribe](https://petals.ml/) to be emailed when it happens or fill in [this form](https://forms.gle/TV3wtRPeHewjZ1vH9) to help the public launch by donating GPU time. In the meantime, you can launch and use your own private swarm. + +## Code examples + +Solving a sequence classification task via soft prompt tuning of BLOOM-176B: + +```python +# Initialize distributed BLOOM with soft prompts +model = AutoModelForPromptTuning.from_pretrained( + "bigscience/distributed-bloom") +# Define optimizer for prompts and linear head +optimizer = torch.optim.AdamW(model.parameters()) + +for input_ids, labels in data_loader: + # Forward pass with local and remote layers + outputs = model.forward(input_ids) + loss = cross_entropy(outputs.logits, labels) + + # Distributed backward w.r.t. local params + loss.backward() # Compute model.prompts.grad + optimizer.step() # Update local params only + optimizer.zero_grad() +``` + +## Installation ```bash conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32 @@ -16,7 +62,6 @@ pip install -r requirements.txt pip install -i https://test.pypi.org/simple/ bitsandbytes-cuda113 ``` - ### Basic functionality All tests is run on localhost @@ -89,3 +134,12 @@ pytest tests/test_block_exact_match.py # test the full model pytest tests/test_full_model.py ``` + +-------------------------------------------------------------------------------- + +

+ This project is a part of the BigScience research workshop. +

+

+ +

From 7653562aa12e76595467d39a7caff64c0408d64a Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Fri, 2 Sep 2022 15:38:04 +0400 Subject: [PATCH 05/10] Use latest version of Petals scheme, shrink Petals logo (#59) --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 3f70ccf..913d223 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@

-
+
Decentralized platform for running 100B+ language models

@@ -22,7 +22,7 @@ ## How it works?

- +

### 🚧 This project is in active development From 9bea7b9ea86614657adc871fd97074ddae74191f Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Sat, 3 Sep 2022 06:38:18 +0400 Subject: [PATCH 06/10] Update bullet points with feedback from Tim and other people (#61) Co-authored-by: Tim Dettmers --- README.md | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 913d223..6f0644f 100644 --- a/README.md +++ b/README.md @@ -11,9 +11,10 @@ ## Key features -- Run inference or fine-tune [BLOOM-176B](https://huggingface.co/bigscience/bloom) by joining compute resources with people all over the Internet. No need to have high-end GPUs. -- One inference step takes ≈ 1 sec — much faster than possible with offloading. Enough for chatbots and other interactive apps. -- Employ any fine-tuning and sampling methods by accessing model's hidden states and changing its control flow — something you can't do in proprietary APIs. +- Run inference or fine-tune large language models like [BLOOM-176B](https://huggingface.co/bigscience/bloom) by joining compute resources with people all over the Internet. No need to have high-end GPUs. +- It's difficult to fit the whole BLOOM-176B into GPU memory [unless](https://twitter.com/Tim_Dettmers/status/1559892918395031552) you have multiple high-end GPUs. Instead, **Petals** allows to load and serve a small part of the model, then team up with people serving all the other parts to run inference or fine-tuning. +- This way, one inference step takes ≈ 1 sec — much faster than possible with offloading. Enough for chatbots and other interactive apps. +- Beyond traditional language model APIs — you can employ any fine-tuning and sampling methods by executing custom paths through the model or accessing its hidden states. This allows for the comforts of an API with the flexibility of PyTorch.

[Read paper] | [View website] From 5f0c5329d4bf5a196ffe609f745b96510145bec9 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Mon, 5 Sep 2022 12:04:50 +0400 Subject: [PATCH 07/10] Update readme with arxiv link and more discussions (#62) Co-authored-by: justheuristic --- README.md | 58 +++++++++++++++++++++++++++++++++++++++---------------- 1 file changed, 41 insertions(+), 17 deletions(-) diff --git a/README.md b/README.md index 6f0644f..f7be4c0 100644 --- a/README.md +++ b/README.md @@ -17,7 +17,7 @@ - Beyond traditional language model APIs — you can employ any fine-tuning and sampling methods by executing custom paths through the model or accessing its hidden states. This allows for the comforts of an API with the flexibility of PyTorch.

- [Read paper] | [View website] + [Read paper] | [View website]

## How it works? @@ -26,36 +26,60 @@

-### 🚧 This project is in active development +### Examples -Be careful: some features may not work, interfaces may change, and we have no detailed docs yet (see [roadmap](https://github.com/bigscience-workshop/petals/issues/12)). +Petals integrates seamlessly with PyTorch and the Hugging Face [Transformers](https://github.com/huggingface/transformers) library. -A stable version of the code and a public swarm open to everyone will be released in November 2022. You can [subscribe](https://petals.ml/) to be emailed when it happens or fill in [this form](https://forms.gle/TV3wtRPeHewjZ1vH9) to help the public launch by donating GPU time. In the meantime, you can launch and use your own private swarm. +This snippet shows how to **(a)** generate text with BLOOM and **(b)** solve a sequence classification task via soft prompt tuning: -## Code examples +```python +# Initialize distributed BLOOM and connect to the swarm +model = DistributedBloomForCausalLM.from_pretrained( + "bigscience/distributed-bloom", tuning_mode="ptune", initial_peers=SEE_BELOW +) # Embeddings & prompts are on your device, BLOOM blocks are distributed -Solving a sequence classification task via soft prompt tuning of BLOOM-176B: +print("Generated:", model.generate(tokenized_prefix, max_new_tokens=5)) -```python -# Initialize distributed BLOOM with soft prompts -model = AutoModelForPromptTuning.from_pretrained( - "bigscience/distributed-bloom") -# Define optimizer for prompts and linear head +# Training (updates only local prompts / adapters) optimizer = torch.optim.AdamW(model.parameters()) - for input_ids, labels in data_loader: - # Forward pass with local and remote layers outputs = model.forward(input_ids) loss = cross_entropy(outputs.logits, labels) - - # Distributed backward w.r.t. local params - loss.backward() # Compute model.prompts.grad - optimizer.step() # Update local params only optimizer.zero_grad() + loss.backward() + optimizer.step() ``` +### 🚧 This project is in active development + +Be careful: some features may not work, interfaces may change, and we have no detailed docs yet (see [roadmap](https://github.com/bigscience-workshop/petals/issues/12)). + +A stable version of the code and a public swarm open to everyone will be released in November 2022. You can [subscribe](https://petals.ml/) to be emailed when it happens or fill in [this form](https://forms.gle/TV3wtRPeHewjZ1vH9) to help the public launch by donating GPU time. In the meantime, you can launch and use your own private swarm. + +### 🔒 Privacy and security + +If you work with sensitive data, you should only use a private swarm (or a subset of servers in the public swarm) hosted by people and institutions you trust, who are authorized to process this data. + +This is important because it's technically possible for peers serving model layers to recover input data or model outputs. Also, if there are malicious peers, they may alter their outputs to influence the model outputs. See a more detailed discussion in Section 4 of our [paper](https://arxiv.org/pdf/2209.01188.pdf). + +## FAQ + +1. **What's the motivation for people to host model layers in the public swarm?** + + People who run inference and fine-tuning themselves get a certain speedup if they host a part of the model locally. Some may be also motivated to "give back" to the community helping them to run the model (similarly to how [BitTorrent](https://en.wikipedia.org/wiki/BitTorrent) users help others by sharing data they have already downloaded). + + Since it may be not enough for everyone, we are also working on introducing explicit __incentives__ ("bloom points") for people donating their GPU time to the public swarm. Once this system is ready, people who earned these points will be able to spend them on inference/fine-tuning with higher priority or increased security guarantees, or (maybe) exchange them for other rewards. + +2. **Why is the platform named "Petals"?** + + "Petals" is a metaphor for people serving different parts of the model. Together, they host the entire language model — [BLOOM](https://huggingface.co/bigscience/bloom). + + While our platform focuses on BLOOM now, we aim to support more [foundation models](https://arxiv.org/abs/2108.07258) in future. + ## Installation +__[To be updated soon]__ + ```bash conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32 pip install torch==1.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html From 54ad745bed9d49b625ff88bd2ce599c2037b6ec9 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Mon, 5 Sep 2022 15:05:59 +0400 Subject: [PATCH 08/10] Warn that current instructions involve 6B model but we will replace them soon (#63) --- README.md | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index f7be4c0..b8cd54b 100644 --- a/README.md +++ b/README.md @@ -26,7 +26,7 @@

-### Examples +### 🛠️ Examples Petals integrates seamlessly with PyTorch and the Hugging Face [Transformers](https://github.com/huggingface/transformers) library. @@ -78,7 +78,9 @@ This is important because it's technically possible for peers serving model laye ## Installation -__[To be updated soon]__ +🚧 **Note:** These are short instructions for running a private swarm with a test 6B version of BLOOM. We will replace them with instructions involving the full 176B BLOOM and more detailed explanations soon (in a day or two). + +-------------------------------------------------------------------------------- ```bash conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32 From ada98a1b378f8a210afb34e048155be8ecbfc08b Mon Sep 17 00:00:00 2001 From: Artem Chumachenko Date: Tue, 6 Sep 2022 21:33:00 +0400 Subject: [PATCH 09/10] Add deep prompt inference (#66) Add deep prompt in inference_step. Small refactoring in deep prompt code. --- src/client/inference_session.py | 42 ++++++++++++++++++--- src/client/remote_generation.py | 5 ++- src/server/backend.py | 21 +++++++++-- src/server/handler.py | 67 +++++++++++++++++++-------------- 4 files changed, 95 insertions(+), 40 deletions(-) diff --git a/src/client/inference_session.py b/src/client/inference_session.py index 24852df..bb1455f 100644 --- a/src/client/inference_session.py +++ b/src/client/inference_session.py @@ -22,6 +22,7 @@ from hivemind.proto import runtime_pb2 from src.client.sequence_manager import RemoteSequenceManager from src.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo from src.server.handler import TransformerConnectionHandler +from src.utils.misc import DUMMY, is_dummy use_hivemind_log_handler("in_root_logger") logger = get_logger(__file__) @@ -44,6 +45,7 @@ class RemoteTransformerBlockInferenceSession: max_length: int, ): self.uid, self.rpc_info = uid, rpc_info + self.num_blocks = uid.count(CHAIN_DELIMITER) + 1 # warning: this code manages async objects that are only usable inside RemoteExpertWorker's background thread; # using them in any other EventLoop may cause side-effects including, headaches, diarrhea, and loss of sleep self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue @@ -69,19 +71,43 @@ class RemoteTransformerBlockInferenceSession: if not next_input_message.uid and not next_input_message.tensors: break # this message means "done sending" - def step(self, new_hidden_states: torch.Tensor): - """Inference step: send a chunk of input tensors and receive a chunk of outputs""" + def step( + self, + new_hidden_states: torch.Tensor, + prompts: Optional[torch.Tensor] = None, + hypo_ids: Optional[torch.Tensor] = None, + ): + """ + Inference step: send a chunk of input tesors and receive a chunk of outputs + :prompts: optional DEEP prompts, added to a prefix of each layer's outputs, + if specified, deep promts should have shape [num_layers, batch_size, prefix_len, hid_size] + """ if self.closed: raise Exception("Session is closed, cannot perform step") + if prompts is None or is_dummy(prompts): + prompts = DUMMY + else: + assert prompts.ndim == 4, "deep promts should have shape [num_layers, batch_size, prefix_len, hid_size]" + assert prompts.shape[0] == self.num_blocks + assert prompts.shape[1] in (new_hidden_states.shape[0], 1) + assert prompts.shape[2] <= new_hidden_states.shape[1] + assert prompts.shape[3] == new_hidden_states.shape[2] + + if hypo_ids is None or is_dummy(hypo_ids): + hypo_ids = DUMMY + else: + assert len(hypo_ids) == len(new_hidden_states) + assert hypo_ids.dtype == torch.int64 + # serialize inputs and put them into the queue - inputs = (new_hidden_states,) + inputs = (new_hidden_states, prompts, hypo_ids) outputs_serialized = RemoteExpertWorker.run_coroutine( self._step( runtime_pb2.ExpertRequest( uid=self.uid, tensors=[ serialize_torch_tensor(tensor.to(proto.dtype), proto.compression) - for tensor, proto in zip(inputs, nested_flatten(self.rpc_info["forward_schema"])) + for tensor, proto in zip(inputs, nested_flatten(self.rpc_info["inference_schema"])) ], metadata=self._serialized_metadata if not self.stepped else None, ) @@ -161,12 +187,16 @@ class RemoteSequentialInferenceSession: return self - def step(self, inputs: torch.Tensor): + def step(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs): assert not self.closed if torch.is_grad_enabled(): logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.") + if prompts is None or is_dummy(prompts): + prompts = DUMMY + else: + assert prompts.ndim == 4 and prompts.shape[0] == len(self.sequence_manager) for session in self.inference_sessions: - outputs = session.step(inputs) + outputs = session.step(inputs, prompts[self.chosen_spans[0].start : self.chosen_spans[0].end], **kwargs) assert outputs.shape == inputs.shape, f"expected {inputs.shape}, got {outputs.shape}" inputs = outputs return inputs diff --git a/src/client/remote_generation.py b/src/client/remote_generation.py index e4875cc..d2be2c9 100644 --- a/src/client/remote_generation.py +++ b/src/client/remote_generation.py @@ -105,11 +105,12 @@ class RemoteGenerationMixin: hypo_ids = torch.arange(outputs[0].size(0)) while True: embs = self.transformer.word_embeddings(outputs[-1]) + intermediate_prompts = None if self.config.pre_seq_len > 0 and len(outputs) == 1: - prompts, _ = self.transformer.get_prompt(embs.size(0)) + prompts, intermediate_prompts = self.transformer.get_prompt(embs.size(0)) embs = torch.cat([prompts, embs], dim=1) embs = self.transformer.word_embeddings_layernorm(embs) - hidden_state = sess.step(embs)[:, -1] + hidden_state = sess.step(embs, prompts=intermediate_prompts, hypo_ids=hypo_ids)[:, -1] hidden_state = self.transformer.ln_f(hidden_state) lm_logits = self.lm_head(hidden_state) diff --git a/src/server/backend.py b/src/server/backend.py index 9929770..27ee1ad 100644 --- a/src/server/backend.py +++ b/src/server/backend.py @@ -1,15 +1,16 @@ """Code for serving bloom blocks via hivemind-server""" from queue import Empty -from typing import Optional, Sequence, Tuple +from typing import Any, Dict, Optional, Sequence, Tuple import torch -from hivemind import use_hivemind_log_handler +from hivemind import BatchTensorDescriptor, use_hivemind_log_handler from hivemind.moe.server.module_backend import ModuleBackend from hivemind.moe.server.task_pool import TaskPool from hivemind.utils import InvalidStateError, get_logger from src.bloom.from_pretrained import BloomBlock from src.server.cache import MemoryCache +from src.utils.misc import is_dummy use_hivemind_log_handler("in_root_logger") logger = get_logger(__file__) @@ -55,18 +56,28 @@ class TransformerBackend(ModuleBackend): self.inference_step, max_batch_size=self.forward_pool.max_batch_size, name=f"{self.name}_inference" ) self.dtype = backend_dtype if backend_dtype else self.module.input_layernorm.weight.dtype + self.inference_schema = ( + ( + *self.args_schema, + BatchTensorDescriptor((), dtype=self.dtype), + BatchTensorDescriptor((), dtype=torch.int64), + ), + self.kwargs_schema, + ) def inference_step(self, cache_metadata: torch.IntTensor, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]: with torch.inference_mode(): attention_cache_handle = int(cache_metadata[0, 0].item()) prefix_length = int(cache_metadata[0, 1].item()) - hidden_states = inputs[0] # todo: in future, it would be best to support attention mask here + (hidden_states, hypo_ids) = inputs assert ( hidden_states.ndim == 3 ), "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]" with self.memory_cache.use_cache(attention_cache_handle) as cache: assert isinstance(self.module, BloomBlock) and cache.shape[0] == 2 and cache.ndim == 5 + if not is_dummy(hypo_ids): + cache[:, :] = cache[:, hypo_ids] # in-place reorder cache by hypo ids layer_past = past_k, past_v = cache[0, :, :prefix_length], cache[1, :, :prefix_length] print("METADATA:", cache_metadata, past_k.shape, past_v.shape) hidden_states, (new_k, new_v) = self.module.forward( @@ -85,3 +96,7 @@ class TransformerBackend(ModuleBackend): def get_pools(self) -> Sequence[TaskPool]: return self.forward_pool, self.backward_pool, self.inference_pool + + def get_info(self) -> Dict[str, Any]: + """Get expert parameters and stats. Used by RemoteExpert to check shapes and for DMoE orchestration.""" + return dict(super().get_info(), inference_schema=self.inference_schema) diff --git a/src/server/handler.py b/src/server/handler.py index 27ed562..b2e15f7 100644 --- a/src/server/handler.py +++ b/src/server/handler.py @@ -64,41 +64,56 @@ class TransformerConnectionHandler(ConnectionHandler): async with self._allocate_caches(requested_backends, batch_size, max_length) as cache_handles: assert len(cache_handles) == len(requested_backends) while request.tensors: # iterate while user is willing to supply tensors - hidden_states = [deserialize_torch_tensor(tensor) for tensor in request.tensors] - length_increment = hidden_states[0].shape[1] # how many tokens are added this step (in each seq) + hidden_states, prompts, hypo_ids = [deserialize_torch_tensor(tensor) for tensor in request.tensors] + # Cast inputs to backend dtype + hidden_states = hidden_states.to(requested_backends[0].dtype) + assert hypo_ids.dtype == torch.int64, f"hypo ids must be int64, got {hypo_ids.dtype}" + + # parse deep prompts (optional argument) + if prompts is None or is_dummy(prompts) or is_dummy(prompts): + prompts = [DUMMY] * len(requested_backends) + else: + prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)] + + if not (len(requested_backends) == len(prompts)): + raise ValueError(f"Received {len(prompts)} prompts for {len(requested_backends)} backends") + + length_increment = hidden_states.shape[1] # how many tokens are added this step (in each seq) if prefix_length + length_increment > max_length: raise ValueError( f"Maximum length exceeded: prefix {prefix_length} + current {length_increment}" f" exceeds pre-allocated maximum {max_length}" ) - # Cast inputs to backend dtype - hidden_states = [tensor.to(requested_backends[0].dtype) for tensor in hidden_states] - # run request tensors through all requested modules, update caches - for backend, cache_handle in zip(requested_backends, cache_handles): + for backend, prompt, cache_handle in zip(requested_backends, prompts, cache_handles): + if not is_dummy(prompt): + hidden_states[:, : prompt.shape[1]] += prompt + cache_metadata[:, 0], cache_metadata[:, 1] = cache_handle, prefix_length + assert isinstance( + hidden_states, torch.Tensor + ), f"hidden states must be tensor, got {type(hidden_states)}" assert ( - len(hidden_states) == 1 and hidden_states[0].ndim == 3 + hidden_states.ndim == 3 ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states" - - hidden_states = await backend.inference_pool.submit_task(cache_metadata, *hidden_states) - assert isinstance(hidden_states, (list, tuple)) - assert len(hidden_states) == 1 and hidden_states[0].ndim == 3 + (hidden_states,) = await backend.inference_pool.submit_task( + cache_metadata, hidden_states, hypo_ids + ) # serialize and send last layer outputs yield runtime_pb2.ExpertResponse( tensors=[ serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True) for result, proto in zip( - hidden_states, nested_flatten(requested_backends[-1].outputs_schema) + (hidden_states,), nested_flatten(requested_backends[-1].outputs_schema) ) ] ) # prepare for next step - prefix_length += hidden_states[0].shape[1] + prefix_length += hidden_states.shape[1] request = await (anext(requests)) finally: print("CLOSED RPC_INFERENCE") @@ -238,23 +253,20 @@ async def _rpc_forward(*flat_tensors: torch.Tensor, requested_backends: Sequence :param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass :returns: hidden states after the last layer [batch_size, seq_length, hid_size] """ - hidden_states, *prompts = flat_tensors + hidden_states, prompts = flat_tensors dtype = requested_backends[0].dtype # check parse input tensors and cast dtypes hidden_states = hidden_states.to(dtype) assert hidden_states.ndim == 3 - if not prompts or is_dummy(prompts[0]): + if prompts is None or is_dummy(prompts): prompts = [DUMMY] * len(requested_backends) - pre_seq_len = 0 else: - prompts = [prompts[0].to(requested_backends[0].dtype)] - prompts = [p.squeeze(0) for p in prompts[0].split(1)] - pre_seq_len = prompts[0].shape[-2] + prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)] # Run a chain of requested backends for backend, prompt in zip(requested_backends, prompts): if not is_dummy(prompt): - hidden_states[:, :pre_seq_len] += prompt + hidden_states[:, : prompt.shape[1]] += prompt (hidden_states,) = await backend.forward_pool.submit_task(hidden_states) assert isinstance(hidden_states, torch.Tensor) assert ( @@ -268,18 +280,15 @@ async def _rpc_forward(*flat_tensors: torch.Tensor, requested_backends: Sequence async def _rpc_backward( *flat_tensors: torch.Tensor, requested_backends: Sequence[TransformerBackend] ) -> Union[torch.Tensor, Sequence[torch.Tensor]]: - inputs, grad_outputs, *prompts = flat_tensors + inputs, grad_outputs, prompts = flat_tensors # Cast inputs & grad outputs to backend dtype inputs = inputs.to(requested_backends[0].dtype) grad_outputs = grad_outputs.to(requested_backends[-1].dtype) - if not prompts or is_dummy(prompts[0]): + if prompts is None or is_dummy(prompts): prompts = [DUMMY] * len(requested_backends) - pre_seq_len = 0 else: - prompts = [prompts[0].to(requested_backends[0].dtype)] - prompts = [p.squeeze(0) for p in prompts[0].split(1)] - pre_seq_len = prompts[0].shape[-2] + prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)] # Run a forward chain to collect intermediate inputs # Note that we do not forward for the last module since we do not need its output @@ -287,13 +296,13 @@ async def _rpc_backward( for backend, prompt in zip(requested_backends[:-1], prompts[:-1]): assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states" if not is_dummy(prompt): - inputs[:, :pre_seq_len] += prompt + inputs[:, : prompt.shape[1]] += prompt inter_inputs.append(inputs) (inputs,) = await backend.forward_pool.submit_task(inputs) assert isinstance(inputs, torch.Tensor) if not is_dummy(prompts[-1]): - inputs[:, :pre_seq_len] += prompts[-1] + inputs[:, : prompts[-1].shape[1]] += prompts[-1] inter_inputs.append(inputs) assert len(inter_inputs) == len(prompts) == len(requested_backends), "internal shape error during backward" @@ -303,7 +312,7 @@ async def _rpc_backward( (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs) assert isinstance(grad_outputs, torch.Tensor) if not is_dummy(prompt): - grad_prompts_reversed.append(grad_outputs[:, :pre_seq_len].unsqueeze(0)) + grad_prompts_reversed.append(grad_outputs[:, : prompt.shape[1]].unsqueeze(0)) grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else DUMMY return [grad_outputs] if is_dummy(grad_prompts) else [grad_outputs, grad_prompts] # TODO un-duct-tape From 8a0c0569299a6d8a6476e10f43d91e58451049cb Mon Sep 17 00:00:00 2001 From: justheuristic Date: Wed, 7 Sep 2022 01:41:23 +0300 Subject: [PATCH 10/10] Fix calling rpc_info multiple times (#60) call info once --- src/client/sequence_manager.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/client/sequence_manager.py b/src/client/sequence_manager.py index af552dd..c05ae72 100644 --- a/src/client/sequence_manager.py +++ b/src/client/sequence_manager.py @@ -141,6 +141,7 @@ class RemoteSequenceManager: stub.rpc_info(runtime_pb2.ExpertUID(uid=self.block_uids[0])) ) self._rpc_info = MSGPackSerializer.loads(outputs.serialized_info) + break except Exception as e: retries += 1 if retries >= self.max_retries: