pull/475/head
artek0chumak 9 months ago
commit d1f0a320c7

@ -41,6 +41,7 @@ jobs:
pip install .[dev]
- name: Test
run: |
set -x # Print executed commands
export MODEL_NAME="${{ matrix.model }}"
export REF_NAME="${{ matrix.model }}"
export ADAPTER_NAME="${{ matrix.model == 'bigscience/bloom-560m' && 'artek0chumak/bloom-560m-safe-peft' || '' }}"

@ -36,7 +36,7 @@ install_requires =
accelerate>=0.20.3,<0.21.0
huggingface-hub>=0.11.1,<1.0.0
tokenizers>=0.13.3
transformers>=4.31.0,<5.0.0
transformers>=4.31.0,<5.0.0 # if you change this, please also change version assert in petals/__init__.py
speedtest-cli==2.1.3
pydantic>=1.10,<2.0 # 2.0 is incompatible with hivemind yet
hivemind==1.1.9
@ -46,7 +46,7 @@ install_requires =
cpufeature>=0.2.0
packaging>=20.9
sentencepiece>=0.1.99
peft>=0.4.0
peft==0.4.0
safetensors>=0.3.1
Dijkstar>=2.6.0

@ -1,4 +1,4 @@
from petals.client.config import ClientConfig
from petals.client.inference_session import InferenceSession
from petals.client.remote_sequential import RemoteSequential
from petals.client.routing.sequence_manager import RemoteSequenceManager
from petals.client.routing.spending_policy import NoSpendingPolicy, SpendingPolicyBase
from petals.client.routing import NoSpendingPolicy, RemoteSequenceManager, SpendingPolicyBase

@ -0,0 +1,31 @@
import dataclasses
from typing import Optional, Sequence, Union
from hivemind import PeerID
from petals.constants import PUBLIC_INITIAL_PEERS
@dataclasses.dataclass
class ClientConfig:
initial_peers: Sequence[str] = tuple(PUBLIC_INITIAL_PEERS) # a list of initial peers for hivemind DHT
dht_prefix: Optional[str] = None # a prefix for all dht keys that correspond to this model (default: model name)
daemon_startup_timeout: int = 60 # timeout for the libp2p daemon connecting to initial peers
show_route: Union[str, bool] = "inference" # show chosen route through servers. one of [False, "inference", True]
allowed_servers: Optional[Sequence[Union[PeerID, str]]] = None # if defined, send requests only to these servers
blocked_servers: Optional[Sequence[Union[PeerID, str]]] = None # if defined, do not use these servers
use_server_to_server: bool = True # Use direct server-to-server communication
connect_timeout: float = 5 # timeout for opening a connection
request_timeout: float = 3 * 60 # timeout for forward/backward/inference requests
update_period: float = 60 # refresh DHT information once in this many seconds
max_retries: Optional[int] = None # max number of retries before the client raises an exception (default: inf)
min_backoff: float = 1 # after a repeated failure, sleep for this many seconds times 2 ** (num_failures - 1)
max_backoff: float = 60 # limit maximal sleep time between retries to this value
ban_timeout: float = 15 # when a remote peer fails to respond, prevent routing to that peer for this many seconds
active_adapter: Optional[str] = None # name of active LoRA adapter (usually, Hugging Face repo)
max_pinged: int = 3 # max servers to ping from each sequence side, per update
ping_timeout: float = 2 # max time to wait for pings, per update
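As a hedged illustration of how these fields fit together (the peer ID below is a placeholder, not a real server), a client could restrict routing and adjust timeouts like this:

from petals.client.config import ClientConfig

config = ClientConfig(
    show_route=True,  # log the chosen route for every request type, not only inference
    blocked_servers=["12D3KooWExamplePeerId"],  # placeholder; peer IDs may be given as base58 strings
    request_timeout=120,  # fail forward/backward/inference requests after two minutes
)

The sequence manager later normalizes these strings into PeerID objects (see _peer_ids_to_set() below) and applies them as allow/block lists whenever it refreshes routing information from the DHT.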

@ -3,7 +3,7 @@ import json
import os
import re
import tempfile
import threading
from contextvars import ContextVar
from typing import List, Optional, Tuple, Union
import torch
@ -47,18 +47,16 @@ class FromPretrainedMixin:
)
_shard_config = threading.local()
_shard_config.ignored_keys = None
_ignored_keys = ContextVar("ignored_keys", default=None)
@contextlib.contextmanager
def ignore_keys(patterns: List[str]):
token = _ignored_keys.set(patterns)
try:
prev_patterns = _shard_config.ignored_keys
_shard_config.ignored_keys = patterns
yield
finally:
_shard_config.ignored_keys = prev_patterns
_ignored_keys.reset(token)
def patched_get_checkpoint_shard_files(
@ -66,7 +64,7 @@ def patched_get_checkpoint_shard_files(
) -> Tuple[List[str], dict]:
"""Same as modeling_utils.get_checkpoint_shard_files(), but does not download shards for the ignored keys."""
should_ignore_keys = _shard_config.ignored_keys is not None
should_ignore_keys = _ignored_keys.get() is not None
tempdir_ctx = tempfile.TemporaryDirectory() if should_ignore_keys else contextlib.nullcontext()
with tempdir_ctx as tempdir:
if should_ignore_keys:
@ -77,7 +75,7 @@ def patched_get_checkpoint_shard_files(
index["weight_map"] = {
param_name: filename
for param_name, filename in index["weight_map"].items()
if all(re.search(pattern, param_name) is None for pattern in _shard_config.ignored_keys)
if all(re.search(pattern, param_name) is None for pattern in _ignored_keys.get())
}
n_loaded_shards = len(set(index["weight_map"].values()))
logger.debug(f"Loading {n_loaded_shards} shards out of {n_original_shards}")
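The hunk above swaps a threading.local for a ContextVar, so the ignored-keys state is scoped per execution context (and propagates into asyncio tasks) instead of per thread. A minimal standalone sketch of the same set/reset-token pattern, using hypothetical names:

import contextlib
from contextvars import ContextVar
from typing import List, Optional

_ignored: ContextVar[Optional[List[str]]] = ContextVar("ignored", default=None)

@contextlib.contextmanager
def ignoring(patterns: List[str]):
    token = _ignored.set(patterns)  # the token remembers the previous value
    try:
        yield
    finally:
        _ignored.reset(token)  # restore it even if the body raised

with ignoring(["lm_head"]):
    assert _ignored.get() == ["lm_head"]
assert _ignored.get() is None  # the default is restored outside the context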

@ -7,22 +7,18 @@ import uuid
from typing import AsyncIterator, List, Optional, Tuple
import torch
from hivemind import (
MSGPackSerializer,
anext,
deserialize_torch_tensor,
get_logger,
nested_flatten,
serialize_torch_tensor,
)
from hivemind import MSGPackSerializer, anext, deserialize_torch_tensor, get_logger, serialize_torch_tensor
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.p2p import P2P
from hivemind.proto import runtime_pb2
from hivemind.utils.tensor_descr import BatchTensorDescriptor
from petals.client.routing.sequence_manager import RemoteSequenceManager, SequenceManagerConfig, maybe_log_traceback
from petals.client.config import ClientConfig
from petals.client.routing import RemoteSequenceManager, maybe_log_traceback
from petals.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
from petals.server.handler import TransformerConnectionHandler
from petals.utils.misc import DUMMY, is_dummy
from petals.utils.misc import DUMMY, DUMMY_INT64, is_dummy
from petals.utils.packaging import pack_args_kwargs
logger = get_logger(__name__)
@ -36,7 +32,7 @@ class _ServerInferenceSession:
def __init__(
self,
config: SequenceManagerConfig,
config: ClientConfig,
span: RemoteSpanInfo,
uid: ModuleUID,
rpc_info: RPCInfo,
@ -63,7 +59,7 @@ class _ServerInferenceSession:
@classmethod
async def create(
cls,
config: SequenceManagerConfig,
config: ClientConfig,
p2p: P2P,
span: RemoteSpanInfo,
uid: ModuleUID,
@ -128,13 +124,13 @@ class _ServerInferenceSession:
assert prompts.shape[3] == inputs.shape[2]
if hypo_ids is None or is_dummy(hypo_ids):
hypo_ids = DUMMY
hypo_ids = DUMMY_INT64
else:
assert len(hypo_ids) == len(inputs)
assert hypo_ids.dtype == torch.int64
# serialize inputs and put them into the queue
input_tensors = (inputs, prompts, hypo_ids)
input_tensors, args_structure = pack_args_kwargs(inputs, prompts, hypo_ids)
request_metadata = dict(session_id=self.session_id, step_id=step_id)
if not self.stepped:
@ -144,13 +140,25 @@ class _ServerInferenceSession:
if next_servers:
request_metadata["next_servers"] = next_servers
request_metadata["args_structure"] = args_structure
# TODO: make it possible to use different compression methods for different tensors
server_side_inference_schema, kwargs_schema = self.rpc_info["inference_schema"]
compression = server_side_inference_schema[0].compression
inference_schema = tuple(BatchTensorDescriptor.from_tensor(arg, compression) for arg in input_tensors)
# TODO: create a more explicit way to check the server's schema against the client's structure
assert len(input_tensors) >= len(
server_side_inference_schema
), "Hidden_state, prompts and hypo_ids tensors are necessary for an inference step"
outputs_serialized = RemoteExpertWorker.run_coroutine(
self._step(
runtime_pb2.ExpertRequest(
uid=self.uid,
tensors=[
serialize_torch_tensor(tensor.to(proto.dtype), proto.compression)
for tensor, proto in zip(input_tensors, nested_flatten(self.rpc_info["inference_schema"]))
for tensor, proto in zip(input_tensors, inference_schema)
],
metadata=MSGPackSerializer.dumps(request_metadata),
)
@ -222,7 +230,7 @@ class InferenceSession:
self._server_sessions = []
self._position = 0
self._max_length = max_length
self.last_token_id = None
self.output_ids = None
@property
def num_blocks(self) -> int:
@ -369,3 +377,13 @@ class InferenceSession:
def __del__(self):
self.close()
@property
def last_token_id(self) -> Optional[torch.Tensor]: # Backward compatibility with Petals < 2.1.0
return self.output_ids[:, -1:] if self.output_ids is not None else None
@last_token_id.setter
def last_token_id(self, value: torch.Tensor): # Backward compatibility with Petals < 2.1.0
if self.output_ids is None:
raise RuntimeError("Can't override `last_token_id` since the session has not stepped yet")
self.output_ids[:, -1:] = value
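A small illustration of the compatibility shim above, assuming `session` is an InferenceSession that has already generated some tokens, so `output_ids` holds a [batch, length] LongTensor:

import torch

# New-style access: the whole sequence produced so far within this session
all_tokens = session.output_ids       # shape [batch, n_tokens]
# Old-style access (Petals < 2.1.0) still works and views only the last column
last = session.last_token_id          # shape [batch, 1]
assert torch.equal(last, all_tokens[:, -1:])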

@ -70,8 +70,8 @@ class LMHead(nn.Module):
if not self._bf16_warning_shown:
if self.weight.numel() * 4 < 0.9 * psutil.virtual_memory().total:
logger.warning(
"Running the client with dtype bfloat16 on CPU may be slow, since your CPU doesn't support AVX512. "
"Consider loading the model with torch_dtype='float32'"
"Running the model in bfloat16 on CPU will be slow since your CPU does not support AVX512. "
"To speed it up, load the model in float32 using .from_pretrained(..., torch_dtype=torch.float32)"
)
self._bf16_warning_shown = True
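The workaround that the reworded warning recommends, as a hedged snippet (the auto class and model name are assumptions for illustration, not part of this diff):

import torch
from petals import AutoDistributedModelForCausalLM  # assumed import path

model = AutoDistributedModelForCausalLM.from_pretrained(
    "bigscience/bloom-560m",      # placeholder model name
    torch_dtype=torch.float32,    # avoids slow bfloat16 math on CPUs without AVX512
)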

@ -76,9 +76,9 @@ def force_non_empty_weights():
[1] https://github.com/huggingface/transformers/blob/ab9fe45236cd99b8797df78219438f8f6662bb42/src/transformers/modeling_utils.py#L2515
"""
possibly_patched_register_parameter = nn.Module.register_parameter
nn.Module.register_parameter = _original_register_parameter
try:
possibly_patched_register_parameter = nn.Module.register_parameter
nn.Module.register_parameter = _original_register_parameter
yield
finally:
nn.Module.register_parameter = possibly_patched_register_parameter

@ -12,13 +12,14 @@ from hivemind.p2p.p2p_daemon_bindings.control import DEFAULT_MAX_MSG_SIZE, MAX_U
from hivemind.proto import runtime_pb2
from hivemind.utils.asyncio import aiter_with_timeout, iter_as_aiter
from hivemind.utils.streaming import split_for_streaming
from hivemind.utils.tensor_descr import BatchTensorDescriptor
from petals.client.routing.sequence_manager import SequenceManagerConfig
from petals.client.config import ClientConfig
from petals.data_structures import ModuleUID, RPCInfo
async def _forward_unary(
uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: SequenceManagerConfig, **kwargs
uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig, **kwargs
) -> List[torch.Tensor]:
outputs: runtime_pb2.ExpertResponse = await stub.rpc_forward(
runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs),
@ -28,7 +29,7 @@ async def _forward_unary(
async def _backward_unary(
uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: SequenceManagerConfig, **kwargs
uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig, **kwargs
) -> List[torch.Tensor]:
grad_inputs: runtime_pb2.ExpertResponse = await stub.rpc_backward(
runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs),
@ -38,7 +39,7 @@ async def _backward_unary(
async def _forward_stream(
uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: SequenceManagerConfig, **kwargs
uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig, **kwargs
) -> List[torch.Tensor]:
parts = (
runtime_pb2.ExpertRequest(uid=uid, tensors=[part], **kwargs)
@ -51,7 +52,7 @@ async def _forward_stream(
async def _backward_stream(
uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: SequenceManagerConfig, **kwargs
uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig, **kwargs
) -> List[torch.Tensor]:
parts = (
runtime_pb2.ExpertRequest(uid=uid, tensors=[part], **kwargs)
@ -68,7 +69,7 @@ async def run_remote_forward(
stub: StubBase,
rpc_info: RPCInfo,
*inputs: torch.Tensor,
config: SequenceManagerConfig,
config: ClientConfig,
metadata: Optional[bytes] = None,
**kwargs,
) -> Tuple[torch.Tensor, ...]:
@ -84,26 +85,20 @@ async def run_remote_forward(
kwargs = {key: kwargs[key] for key in rpc_info["keyword_names"]}
# Note: we put keyword arguments in the same order as on a server to prevent f(a=1, b=2) != f(b=2, a=1) errors
forward_inputs = (inputs, kwargs)
# Modify forward_schema to support prompts
forward_inputs = tuple(nested_flatten((inputs, kwargs)))
args_schema, kwargs_schema = rpc_info["forward_schema"]
# TODO: rm this assert when support arbitrary number of input tensors
assert len(args_schema) == 1 and len(inputs) == 2
forward_schema_with_prompts = (tuple(args_schema * len(inputs)), kwargs_schema)
if not nested_compare(forward_inputs, forward_schema_with_prompts):
raise TypeError(f"Inputs do not match expert input schema. Did you pass the right number of parameters?")
forward_inputs = nested_flatten(forward_inputs)
compression = args_schema[0].compression
forward_schema = tuple(BatchTensorDescriptor.from_tensor(arg, compression) for arg in forward_inputs)
inputs = tuple(tensor.cpu().detach() for tensor in forward_inputs)
# TODO: create a more explicit way to check the server's schema against the client's structure
assert len(inputs) >= len(args_schema) + 1, "Inputs and prompt tensors are necessary for a forward step"
# Asynchronous serialization
loop = asyncio.get_running_loop()
serialized_tensors = await asyncio.gather(
*(
loop.run_in_executor(None, serialize_torch_tensor, tensor.to(proto.dtype), proto.compression)
for tensor, proto in zip(inputs, nested_flatten(forward_schema_with_prompts))
for tensor, proto in zip(inputs, forward_schema)
)
)
@ -119,10 +114,8 @@ async def run_remote_backward(
uid: ModuleUID,
stub: StubBase,
rpc_info: RPCInfo,
inputs: torch.Tensor,
grad_outputs: List[torch.Tensor],
*extra_tensors: torch.Tensor,
config: SequenceManagerConfig,
*inputs_and_grad_outputs: torch.Tensor,
config: ClientConfig,
metadata: Optional[bytes] = None,
**kwargs,
) -> Sequence[torch.Tensor]:
@ -131,16 +124,14 @@ async def run_remote_backward(
Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L221
but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here.
"""
grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
inputs_and_grad_outputs = tuple(nested_flatten((inputs, grad_outputs_cpu, *extra_tensors)))
# Modify forward_schema to support prompts
args_schema, kwargs_schema = rpc_info["forward_schema"]
assert len(args_schema) == 1 and isinstance(inputs, torch.Tensor)
# TODO generalize this
prompts_schema = next(iter(args_schema))
backward_schema = tuple(nested_flatten((rpc_info["forward_schema"], rpc_info["outputs_schema"], prompts_schema)))
outputs_schema = rpc_info["outputs_schema"]
compression = args_schema[0].compression
backward_schema = tuple(BatchTensorDescriptor.from_tensor(arg, compression) for arg in inputs_and_grad_outputs)
# TODO: create a more explicit way to check the server's schema against the client's structure
assert (
len(inputs_and_grad_outputs) >= len(args_schema) + len(outputs_schema) + 1
), "Inputs, grad_outputs and prompt tensors are necessary for a backward step"
# Asynchronous serialization
loop = asyncio.get_running_loop()

@ -1,349 +1,142 @@
import contextlib
from typing import List, Optional
import dataclasses
from contextvars import ContextVar
from typing import ContextManager, List, Optional
import torch
import transformers
from hivemind.utils.logging import get_logger
from transformers.generation.utils import ModelOutput
from petals.client.inference_session import InferenceSession
from petals.utils.generation_algorithms import (
BeamSearchAlgorithm,
DecodingAlgorithm,
GreedyAlgorithm,
NucleusAlgorithm,
SamplingAlgorithm,
TopKAlgorithm,
)
from petals.utils.generation_constraints import ABCBloomConstraint, EosConstraint
from petals.client.remote_sequential import RemoteSequential
from petals.utils.misc import DUMMY, docstring_from
logger = get_logger(__name__)
class RemoteGenerationMixin:
"""
A class containing all functions for auto-regressive text generation, to be used as a mixin in [`BloomForCausalLM`].
The class can be used for:
- *greedy decoding*.
- *multinomial, top-k and top-p sampling*.
- *beam-search decoding*
This class is similar to transformers' [`generation_utils.GenerationMixin`] and can be used instead of it.
However, it has some differences for remote usage.
"""
def inference_session(self, **kwargs) -> InferenceSession:
"""
Returns an inference session for the model's RemoteSequential module.
@dataclasses.dataclass(frozen=True)
class RemotePastKeyValues:
"""A mock class representing the fact that `past_key_values` do exist but are stored on remote servers."""
:param max_length: Maximal expected length of inference results. Servers use this parameter
to calculate the size of attention caches allocated to this client.
"""
hypo_ids: Optional[torch.LongTensor] = None
return self.transformer.h.inference_session(**kwargs)
def __getitem__(self, _index: int) -> List[torch.Tensor]:
return [DUMMY] # For compatibility with BloomForCausalLM.prepare_inputs_for_generation()
@torch.inference_mode()
def generate(
self,
inputs: Optional[torch.Tensor] = None,
*,
do_sample: Optional[bool] = None,
temperature: float = 1.0,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
num_beams: Optional[int] = 1,
bos_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
pad_token_id: Optional[int] = None,
max_length: Optional[int] = None,
max_new_tokens: Optional[int] = None,
decoding_algorithm: Optional[DecodingAlgorithm] = None,
provided_constraints: List[ABCBloomConstraint] = [],
num_return_sequences: Optional[int] = None,
session: Optional[InferenceSession] = None,
) -> torch.LongTensor:
"""
Generates sequences of token ids for models with a language modeling head.
:param inputs: The input tokens to the model.
:param do_sample: Whether to sample from the model predictions or take the argmax.
:param temperature: The temperature to use for sampling.
:param top_k: The number of results to return.
:param top_p: The cumulative probability of results to return.
:param num_beams: The number of beams to use for beam search.
:param bos_token_id: The id of the beginning of sentence token.
:param eos_token_id: The id of the end of sentence token.
:param pad_token_id: The id of the padding token.
:param max_length: The maximum number of tokens in the output (including input tokens).
:param max_new_tokens: The maximum number of tokens to generate.
:param decoding_algorithm: The decoding algorithm to use.
:param provided_constraints: A list of constraints to use.
:param num_return_sequences: How many hypothesis from the beam will be in output.
"""
_skipped_tokens = ContextVar("skipped_tokens", default=0)
prefix_length = 0 if inputs is None else inputs.size(1)
prefix_length += self.config.pre_seq_len
bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
class _SkipTokensMixin:
# This override is used in RemoteGenerationMixin but has to be defined in a class not named "GenerationMixin"
# due to how transformers.PreTrainedModel.can_generate() works
def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> dict:
input_ids = input_ids[:, _skipped_tokens.get() :]
_skipped_tokens.set(0)
return super().prepare_inputs_for_generation(input_ids, **kwargs)
assert (max_length is None) != (max_new_tokens is None), "please set max_length or max_new_tokens (not both)"
if max_length is not None and max_new_tokens is None:
max_new_tokens = max_length - prefix_length
assert max_new_tokens > 0, f"Provided max_length is less than prefix size: {max_length} < {inputs.size(1)}"
elif max_length is None and max_new_tokens is not None:
max_length = prefix_length + max_new_tokens
resuming_session = session is not None and session.last_token_id is not None
if num_beams > 1 and resuming_session:
raise NotImplementedError(
"Resuming inference session in .generate() along with beam search is not supported yet"
)
class RemoteGenerationMixin(_SkipTokensMixin):
"""
This class is an upgrade to `transformers.GenerationMixin` that:
- Is designed to be compatible with most `transformers.GenerationMixin` strategies and options
- Supports generation inside a remote InferenceSession, so that remote servers store your attention caches and
you don't have to rerun the prefix through all the servers to generate each new token
- Supports multiple `.generate()` calls inside one InferenceSession, so you can easily run interactive generation
by showing tokens on the fly (multiple calls like `.generate(None, max_new_tokens=1, ...)`) or
accepting prompts from a user in a chatbot (multiple calls like `.generate(new_prompts, ...)`).
- If there is no active session, `.generate()` will create a new InferenceSession with proper `max_length`.
Otherwise, `.generate()` will use the active session. You can use the `session=...` argument to override that.
"""
if inputs is not None:
assert isinstance(inputs, torch.Tensor) and inputs.ndim == 2, "inputs must be a 2d tensor [batch, length]"
if resuming_session:
inputs = torch.cat([session.last_token_id, inputs], dim=1)
else:
if resuming_session:
inputs = session.last_token_id
else:
assert bos_token_id is not None, "You have to provide a bos_token_id if you do not provide inputs"
inputs = torch.tensor([[bos_token_id]] * num_beams, dtype=torch.long, device=self.device)
batch_size = inputs.size(0)
@docstring_from(RemoteSequential.active_session)
@property
def active_session(self) -> Optional[InferenceSession]:
return self.transformer.h.active_session
if decoding_algorithm is None:
if do_sample:
decoding_algorithm = self._choose_sample_algorithm(temperature, top_k, top_p)
elif num_beams is not None and num_beams > 1:
decoding_algorithm = BeamSearchAlgorithm(num_beams, batch_size=batch_size)
else:
if top_k is not None or top_p is not None:
logger.warning("You passed top_k or top_p but did not pass do_sample=True. Running greedy sampling")
decoding_algorithm = GreedyAlgorithm()
@docstring_from(RemoteSequential.use_session)
def use_session(self, session: Optional[InferenceSession]) -> ContextManager[InferenceSession]:
return self.transformer.h.use_session(session)
if num_beams > 1:
inputs = torch.cat([inputs] * num_beams, dim=0)
if batch_size > 1:
# TODO: resolve padding problem
logger.warning(
f"You set batch_size {batch_size} within beam search generation. "
f"Be careful: results for sequences of different lengths may be padded incorrectly"
)
@docstring_from(RemoteSequential.inference_session)
def inference_session(self, **kwargs) -> ContextManager[InferenceSession]:
return self.transformer.h.inference_session(**kwargs)
if num_return_sequences is None:
num_return_sequences = 1
@docstring_from(transformers.GenerationMixin.generate.__doc__)
def generate(
self, inputs: Optional[torch.Tensor] = None, *args, session: Optional[InferenceSession] = None, **kwargs
):
self._fix_generate_kwargs(kwargs)
if session is not None:
# If a session specified explicitly, use it
context_manager = self.use_session(session)
elif self.active_session is not None:
# If there's an active session, don't do anything
context_manager = contextlib.nullcontext(self.active_session)
else:
# If there's no active session, create a new one
assert num_return_sequences <= num_beams, (
f"You want more sequences than the beam has."
" Check num_return_sequences: {num_return_sequences} and num_beams: {num_beams}."
)
max_length = kwargs.get("max_length")
max_new_tokens = kwargs.get("max_new_tokens")
assert (max_length is None) != (
max_new_tokens is None
), "You should set `max_length` or `max_new_tokens` (but not both) to reserve server-side attention caches"
constraints = self._get_constraints(
inputs=inputs,
eos_token_id=eos_token_id,
pad_token_id=pad_token_id,
provided_constraints=provided_constraints,
)
if max_length is not None:
session_max_length = max_length
else:
session_max_length = (inputs.shape[1] if inputs is not None else 0) + max_new_tokens
context_manager = self.inference_session(max_length=session_max_length)
if session is None:
context_manager = self.inference_session(max_length=max_length)
else:
context_manager = contextlib.nullcontext(session) # Doesn't actually enter session or exit from it
with context_manager as session:
outputs = []
# Find samples with padded inputs.
# They will be changed before all of the samples have right length.
if torch.any(inputs == pad_token_id): # TODO: move to prepare_inputs
outputs += [inputs[:, : inputs.size(1) - (inputs == pad_token_id).sum(-1).max()]]
# Prepend the tokens from the previous .generate() call
n_prev_tokens = session.output_ids.shape[1] if session.output_ids is not None else 0
if n_prev_tokens > 0:
if kwargs.get("num_beams", 1) > 1:
logger.warning(
"Beam search will not work properly in the resumed petals.InferenceSession "
"since intermediate beam entries are lost"
)
if inputs is not None:
inputs = torch.cat([session.output_ids, inputs], dim=1)
else:
inputs = session.output_ids
# Don't actually run all previous tokens through the transformer,
# but keep them for transformers.GenerationMixin (e.g., to compute repetition_penalty)
_skipped_tokens.set(max(0, n_prev_tokens - 1))
result = super().generate(inputs, *args, **kwargs)
sequences = result.sequences if isinstance(result, ModelOutput) else result
# Save tokens from this .generate() call
session.output_ids = sequences
# Crop the last tokens from the previous call
sequences = sequences[:, n_prev_tokens:].clone()
if isinstance(result, ModelOutput):
result.sequences = sequences
else:
outputs += [inputs]
last_token_id = None
seq_idx = outputs[0].size(1)
hypo_ids = torch.arange(outputs[0].size(0))
while True:
hidden_state = self.transformer.word_embeddings(outputs[-1])
intermediate_prompts = None
if self.config.pre_seq_len > 0 and len(outputs) == 1:
prompts, intermediate_prompts = self.transformer.get_prompt(hidden_state.size(0))
hidden_state = torch.cat([prompts, hidden_state], dim=1)
hidden_state = self.transformer.word_embeddings_layernorm(hidden_state)
result = sequences
hidden_state = session.step(hidden_state, prompts=intermediate_prompts, hypo_ids=hypo_ids)[:, -1]
return result
hidden_state = self.transformer.ln_f(hidden_state)
lm_logits = self.lm_head(hidden_state)
@staticmethod
def _fix_generate_kwargs(kwargs: dict) -> dict:
# Suppress inappropriate "Both max_new_tokens and max_length" HF warning
if "max_length" in kwargs and kwargs["max_length"] is None:
del kwargs["max_length"]
for constraint in constraints:
lm_logits = constraint(last_token_id, lm_logits, hypo_ids)
last_token_id, hypo_ids = decoding_algorithm(lm_logits)
# Support do_sample = {0, 1} for backward compatibility with Petals < 2.1.0
do_sample = kwargs.get("do_sample")
if isinstance(do_sample, int):
kwargs["do_sample"] = bool(do_sample)
# If some samples were padded, change only these samples
if seq_idx < inputs.size(1):
pad_token_mask = inputs[:, seq_idx : seq_idx + 1] == pad_token_id
last_token_id = (~pad_token_mask) * inputs[
:, seq_idx : seq_idx + 1
] + pad_token_mask * last_token_id
# TODO: refactor outputs
if num_beams > 1:
for i in range(len(outputs), 1, -1):
outputs[i - 1] = outputs[i - 1][hypo_ids]
outputs.append(last_token_id)
session.last_token_id = last_token_id
seq_idx += 1
if torch.all(last_token_id == eos_token_id) or len(outputs) > max_new_tokens:
break
outputs = torch.cat(outputs, dim=-1)
if resuming_session:
outputs = outputs[:, 1:]
if num_beams > 1:
pre_return_idx = [
torch.arange(idx, num_return_sequences * batch_size, batch_size) for idx in range(batch_size)
]
return_idx = torch.cat(pre_return_idx, dim=0)
outputs = outputs[return_idx]
return outputs
def greedy_search(
self,
input_ids: torch.LongTensor,
max_length: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
provided_constraints: List[ABCBloomConstraint] = [],
) -> torch.LongTensor:
"""
Generates sequences of token ids for models with a language modeling head. Uses greedy search.
:param input_ids: The input tokens to the model.
:param max_length: The maximum length of the sequence to generate.
:param pad_token_id: The id of the padding token.
:param eos_token_id: The id of the end of sentence token.
:param provided_constraints: A list of constraints to use.
"""
return self.generate(
inputs=input_ids,
max_new_tokens=max_length,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
decoding_algorithm=GreedyAlgorithm(),
provided_constraints=provided_constraints,
)
def sample(
self,
input_ids: torch.LongTensor,
temperature: float = 1.0,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
max_length: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
provided_constraints: List[ABCBloomConstraint] = [],
) -> torch.LongTensor:
"""
Generates sequences of token ids for models with a language modeling head. Uses multinomial sampling.
If top_k is provided, uses top_k sampling. If top_p is provided, uses nucleus sampling.
:param: input_ids: The input tokens to the model.
:param: temperature: The temperature to use for sampling.
:param: top_k: The number of samples to use for top_k sampling.
:param: top_p: The probability of using top_p sampling.
:param: max_length: The maximum length of the sequence to generate.
:param: pad_token_id: The id of the padding token.
:param: eos_token_id: The id of the end of sentence token.
:param: provided_constraints: A list of constraints to use.
"""
return self.generate(
inputs=input_ids,
max_new_tokens=max_length,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
decoding_algorithm=self._choose_sample_algorithm(temperature, top_k, top_p),
provided_constraints=provided_constraints,
)
def beam_search(
self,
input_ids: torch.LongTensor,
num_beams: int = 1,
max_length: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
provided_constraints: List[ABCBloomConstraint] = [],
) -> torch.LongTensor:
"""
Generates sequences of token ids for models with a language modeling head. Uses beam search.
:param input_ids: The input tokens to the model.
:param num_beams: The number of beams to use.
:param max_length: The maximum length of the sequence to generate.
:param pad_token_id: The id of the padding token.
:param eos_token_id: The id of the end of sentence token.
:param provided_constraints: A list of constraints to use.
"""
decoding_algorithm = BeamSearchAlgorithm(
num_beams=num_beams,
batch_size=input_ids.size(0),
)
return self.generate(
inputs=input_ids,
num_beams=num_beams,
max_new_tokens=max_length,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
decoding_algorithm=decoding_algorithm,
provided_constraints=provided_constraints,
)
def beam_sample(
self,
input_ids: torch.LongTensor,
max_length: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
provided_constraints: List[ABCBloomConstraint] = [],
) -> torch.LongTensor:
raise NotImplementedError
def group_beam_search(
self,
input_ids: torch.LongTensor,
max_length: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
provided_constraints: List[ABCBloomConstraint] = [],
) -> torch.LongTensor:
raise NotImplementedError
def _choose_sample_algorithm(
self,
temperature: float = 1.0,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
) -> DecodingAlgorithm:
if (top_k is not None) and (top_p is not None):
raise ValueError("You have to provide only top_k or top_p for sampling")
if top_k is not None:
return TopKAlgorithm(top_k, temperature)
elif top_p is not None:
return NucleusAlgorithm(top_p, temperature)
else:
return SamplingAlgorithm(temperature)
return kwargs
def _get_constraints(
self,
inputs: Optional[torch.Tensor] = None,
eos_token_id: Optional[int] = None,
pad_token_id: Optional[int] = None,
provided_constraints: List[ABCBloomConstraint] = [],
) -> List[ABCBloomConstraint]:
constraints = []
constraints.extend(provided_constraints)
constraints.append(EosConstraint(inputs, eos_token_id, pad_token_id))
return constraints
@staticmethod
def _reorder_cache(past_key_values: RemotePastKeyValues, beam_idx: torch.LongTensor) -> RemotePastKeyValues:
return dataclasses.replace(past_key_values, hypo_ids=beam_idx)
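A hedged usage sketch of the behaviour the new docstring describes, where several `.generate()` calls share one InferenceSession so the remote attention caches are reused; the model and tokenizer setup is an assumption for illustration, not part of this diff:

from transformers import AutoTokenizer
from petals import AutoDistributedModelForCausalLM  # assumed import path

MODEL_NAME = "bigscience/bloom-560m"  # placeholder model choice
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoDistributedModelForCausalLM.from_pretrained(MODEL_NAME)

inputs = tokenizer("A cat sat on", return_tensors="pt")["input_ids"]
with model.inference_session(max_length=64) as session:
    # The first call runs the prefix through the servers and leaves the caches there
    outputs = model.generate(inputs, max_new_tokens=8)
    # Follow-up calls resume from session.output_ids; passing None simply continues
    outputs = model.generate(None, max_new_tokens=8)
print(tokenizer.decode(outputs[0]))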

@ -1,16 +1,18 @@
from __future__ import annotations
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Optional, Union
import torch
from hivemind import DHT, get_logger
from torch import nn
from petals.client.config import ClientConfig
from petals.client.inference_session import InferenceSession
from petals.client.routing.sequence_manager import RemoteSequenceManager, SequenceManagerConfig
from petals.client.routing import RemoteSequenceManager
from petals.client.sequential_autograd import _RemoteSequentialAutogradFunction
from petals.data_structures import UID_DELIMITER
from petals.utils.misc import DUMMY
logger = get_logger(__name__)
@ -22,7 +24,7 @@ class RemoteSequential(nn.Module):
def __init__(
self,
config: SequenceManagerConfig,
config: ClientConfig,
*,
sequence_manager: Optional[RemoteSequenceManager] = None,
dht: Optional[DHT] = None,
@ -45,11 +47,52 @@ class RemoteSequential(nn.Module):
sequence_manager = RemoteSequenceManager(config, block_uids, dht=dht, **kwargs)
self.sequence_manager = sequence_manager
def forward(self, inputs: torch.Tensor, prompts: torch.Tensor = DUMMY):
self._active_session = ContextVar("active_session", default=None)
def forward(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
assert inputs.ndim == 3, "inputs must be a tensor of shape [batch_size, seq_length, hidden_size]"
assert inputs.shape[1] <= 2048, "The sequence length is capped at 2048 tokens in this version"
outputs = _RemoteSequentialAutogradFunction.apply(inputs, prompts, self.sequence_manager)
return outputs
if self.active_session is None:
assert all(v is None for v in kwargs.values()), f"Extra kwargs are not supported in forward: {kwargs}"
return _RemoteSequentialAutogradFunction.apply(inputs, prompts, self.sequence_manager)
else:
return self.active_session.step(inputs, prompts, **kwargs)
@property
def active_session(self) -> Optional[InferenceSession]:
"""
If called inside `with model.inference_session(...):` or `with model.use_session(...):`,
returns an active InferenceSession. Otherwise, returns None.
"""
return self._active_session.get()
@property
def position(self) -> int:
"""Returns the prefix length (in tokens) in the active inference session or zero if no session is active."""
return self.active_session.position if self.active_session is not None else 0
@contextmanager
def use_session(self, session: Optional[InferenceSession]) -> InferenceSession:
"""Inside this context, forward() will use an _existing_ InferenceSession provided as the argument."""
token = self._active_session.set(session)
try:
yield session
finally:
self._active_session.reset(token)
@contextmanager
def inference_session(self, **kwargs) -> InferenceSession:
"""
Inside this context, forward() will use a _new_ InferenceSession created with given parameters.
:param max_length: Maximal expected length of inference results. Servers use this parameter
to calculate the size of attention caches allocated to this client.
"""
with InferenceSession(self.sequence_manager, **kwargs) as session, self.use_session(session):
yield session
def __getitem__(self, ix: Union[int, slice]) -> RemoteSequential:
return RemoteSequential(
@ -64,8 +107,5 @@ class RemoteSequential(nn.Module):
def __len__(self):
return len(self.sequence_manager)
def inference_session(self, **kwargs) -> InferenceSession:
return InferenceSession(self.sequence_manager, **kwargs)
def extra_repr(self) -> str:
return f"modules={self.sequence_manager.block_uids[0]}..{self.sequence_manager.block_uids[-1]}"
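A short sketch of the session-aware forward() path added above, assuming `blocks` is an already constructed RemoteSequential and `hidden_states` is a [batch, seq_len, hidden_size] float tensor:

# Outside of any session, forward() goes through the autograd function as before
outputs = blocks(hidden_states)

# Inside a session, forward() delegates to InferenceSession.step(), so the servers
# keep attention caches between calls and only the new positions are processed
with blocks.inference_session(max_length=2048) as session:
    prefix_out = blocks(hidden_states[:, :-1])  # processed and cached remotely
    next_out = blocks(hidden_states[:, -1:])    # continues from the cached prefix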

@ -1 +1,2 @@
"""Client-side functions responsible for choosing the best server, """
from petals.client.routing.sequence_manager import RemoteSequenceManager, maybe_log_traceback
from petals.client.routing.spending_policy import NoSpendingPolicy, SpendingPolicyBase

@ -7,7 +7,8 @@ import logging
import random
import threading
import time
from typing import Any, Collection, Dict, List, Optional, Sequence, Union
import warnings
from typing import Any, Dict, List, Optional, Sequence, Set, Union
from weakref import WeakMethod
import dijkstar
@ -18,40 +19,27 @@ from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.proto import runtime_pb2
from hivemind.utils.logging import get_logger
import petals.dht_utils
from petals.client.config import ClientConfig
from petals.client.routing.sequence_info import RemoteSequenceInfo
from petals.client.routing.spending_policy import NoSpendingPolicy
from petals.constants import PUBLIC_INITIAL_PEERS
from petals.data_structures import ModuleUID, RemoteSpanInfo, ServerState
from petals.server.handler import TransformerConnectionHandler
from petals.utils.dht import get_remote_module_infos
from petals.utils.ping import PingAggregator
from petals.utils.random import sample_up_to
logger = get_logger(__name__)
@dataclasses.dataclass
class SequenceManagerConfig:
initial_peers: Sequence[str] = tuple(PUBLIC_INITIAL_PEERS) # a list of initial peers for hivemind DHT
dht_prefix: Optional[str] = None # a prefix for all dht keys that correspond to this model (default: model name)
daemon_startup_timeout: int = 60 # timeout for the libp2p daemon connecting to initial peers
show_route: Union[str, bool] = "inference" # show chosen route through servers. one of [False, "inference", True]
allowed_servers: Optional[Collection[Union[PeerID, str]]] = None # if defined, send requests only to these servers
use_server_to_server: bool = True # Use direct server-to-server communication
connect_timeout: float = 5 # timeout for opening a connection
request_timeout: float = 3 * 60 # timeout for forward/backward/inference requests
update_period: float = 60 # refresh DHT information once in this many seconds
max_retries: Optional[int] = None # max number retries before the client raises an exception (default: inf)
min_backoff: float = 1 # after a repeated failure, sleep for this many seconds times 2 ** (num_failures - 1)
max_backoff: float = 60 # limit maximal sleep time between retries to this value
ban_timeout: float = 15 # when a remote peer fails to respond, prevent routing to that peer for this many seconds
active_adapter: Optional[str] = None # name of active LoRA adapter (usually, Hugging Face repo)
max_pinged: int = 3 # max servers to ping from each sequence side, per update
ping_timeout: float = 2 # max time to wait for pings, per update
class SequenceManagerConfig(ClientConfig):
def __init__(self, *args, **kwargs):
warnings.warn(
"petals.client.routing.SequenceManagerConfig has been moved to petals.ClientConfig. "
"This alias will be removed in Petals 2.2.0+",
DeprecationWarning,
stacklevel=2,
)
super().__init__(*args, **kwargs)
@dataclasses.dataclass
@ -82,7 +70,7 @@ class RemoteSequenceManager:
def __init__(
self,
config: SequenceManagerConfig,
config: ClientConfig,
block_uids: Sequence[ModuleUID],
*,
dht: Optional[DHT] = None,
@ -116,6 +104,9 @@ class RemoteSequenceManager:
self._thread_start_lock = threading.Lock()
self.policy = NoSpendingPolicy()
self.allowed_servers = self._peer_ids_to_set(config.allowed_servers)
self.blocked_servers = self._peer_ids_to_set(config.blocked_servers)
self.ping_aggregator = PingAggregator(dht)
if state.banned_peers is None:
@ -128,6 +119,23 @@ class RemoteSequenceManager:
self._thread.ready.set() # no need to await the first dht fetch
self._need_latest_infos = True
@staticmethod
def _peer_ids_to_set(peer_ids: Optional[Sequence[Union[PeerID, str]]]) -> Optional[Set[PeerID]]:
if peer_ids is None:
return None
result = set()
for peer_id in peer_ids:
if isinstance(peer_id, PeerID):
result.add(peer_id)
elif isinstance(peer_id, str):
result.add(PeerID.from_base58(peer_id))
else:
raise TypeError(
f"`allowed_servers` and `blocked_servers` have to contain only PeerIDs or strings, but got {type(peer_id)}"
)
return result
def make_sequence(
self,
start_index: int = 0,
@ -333,7 +341,7 @@ class RemoteSequenceManager:
def _update(self):
"""Perform an immediate and synchronous refresh, may take time"""
new_block_infos = petals.dht_utils.get_remote_module_infos(
new_block_infos = get_remote_module_infos(
self.dht, self.block_uids, active_adapter=self.config.active_adapter, latest=True
)
@ -341,13 +349,13 @@ class RemoteSequenceManager:
if not block_info:
continue
# Apply whitelist, if defined
if self.config.allowed_servers is not None:
block_info.servers = {
peer_id: server_info
for peer_id, server_info in block_info.servers.items()
if peer_id in self.config.allowed_servers or str(peer_id) in self.config.allowed_servers
}
# Apply allow and block lists
block_info.servers = {
peer_id: server_info
for peer_id, server_info in block_info.servers.items()
if (self.allowed_servers is None or peer_id in self.allowed_servers)
and (self.blocked_servers is None or peer_id not in self.blocked_servers)
}
# Remove temporarily banned peers, unless there are no peers left
valid_servers = {
@ -466,14 +474,21 @@ class RemoteSequenceManager:
return 0
return min(self.config.min_backoff * 2 ** (attempt_no - 1), self.config.max_backoff)
def get_request_metadata(self, protocol: str, *args, **kwargs) -> Optional[Dict[str, Any]]:
def get_request_metadata(
self, protocol: str, args_structure: Any = None, *args, **kwargs
) -> Optional[Dict[str, Any]]:
"""
:param protocol: one of "rpc_forward", "rpc_backward" or "rpc_inference"
:param args_structure: the structure of flattened tensors from pack_args_kwargs in petals.utils.packaging
:param args: request-specific inputs, typically block uids and input tensors
:param kwargs: additional request context, such as remote peer ID
:returns: msgpack-serialized metadata dict that will be passed alongside a given request
"""
return dict(points=self.policy.get_points(protocol, *args, **kwargs), active_adapter=self.config.active_adapter)
return dict(
points=self.policy.get_points(protocol, *args, **kwargs),
active_adapter=self.config.active_adapter,
args_structure=args_structure,
)
def shutdown(self):
self._thread.shutdown()

@ -12,10 +12,11 @@ from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.utils.logging import get_logger
from petals.client.remote_forward_backward import run_remote_backward, run_remote_forward
from petals.client.routing.sequence_manager import RemoteSequenceManager, maybe_log_traceback
from petals.client.routing import RemoteSequenceManager, maybe_log_traceback
from petals.data_structures import CHAIN_DELIMITER, RemoteSpanInfo
from petals.server.handler import TransformerConnectionHandler
from petals.utils.misc import DUMMY, is_dummy
from petals.utils.packaging import pack_args_kwargs
logger = get_logger(__name__)
@ -67,15 +68,17 @@ async def sequential_forward(
span = sequences.popleft()
stub = TransformerConnectionHandler.get_stub(sequence_manager.state.p2p, span.peer_id)
inputs_and_prompts = [inputs, prompts[span.start : span.end]]
flat_tensors, args_structure = pack_args_kwargs(inputs, prompts[span.start : span.end])
span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
metadata = sequence_manager.get_request_metadata("rpc_forward", span_uids, *inputs_and_prompts)
metadata = sequence_manager.get_request_metadata(
"rpc_forward", args_structure, span_uids, *flat_tensors
)
(outputs,) = await run_remote_forward(
span_uids,
stub,
sequence_manager.rpc_info,
*inputs_and_prompts,
*flat_tensors,
config=sequence_manager.config,
metadata=MSGPackSerializer.dumps(metadata),
)
@ -149,18 +152,21 @@ async def sequential_backward(
inputs = intermediate_inputs.pop()
span = forward_sequences.pop()
grad_outputs_cpu = [grad.cpu() for grad in grad_outputs]
flat_tensors, args_structure = pack_args_kwargs(
inputs, *grad_outputs_cpu, prompts[span.start : span.end]
)
span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
stub = TransformerConnectionHandler.get_stub(sequence_manager.state.p2p, span.peer_id)
metadata = sequence_manager.get_request_metadata(
"rpc_backward", span_uids, *inputs, *grad_outputs, peer_id=span.peer_id
"rpc_backward", args_structure, span_uids, *flat_tensors, peer_id=span.peer_id
)
grad_outputs, *span_grad_prompts = await run_remote_backward(
span_uids,
stub,
sequence_manager.rpc_info,
inputs,
grad_outputs,
prompts[span.start : span.end],
*flat_tensors,
config=sequence_manager.config,
metadata=MSGPackSerializer.dumps(metadata),
)
@ -224,7 +230,7 @@ class _RemoteSequentialAutogradFunction(torch.autograd.Function):
def forward(ctx, inputs: torch.Tensor, prompts: torch.Tensor, sequence_manager: RemoteSequenceManager):
batch_size = max(MAX_TOKENS_IN_BATCH // inputs.shape[1], 1)
input_batches: Sequence[torch.Tensor] = inputs.detach().split(batch_size)
if is_dummy(prompts):
if prompts is None or is_dummy(prompts):
prompt_batches = [DUMMY] * len(input_batches)
else:
prompt_batches: Sequence[torch.Tensor] = prompts.detach().split(batch_size, dim=1)
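A hedged sketch of the packing round-trip these call sites rely on; the exact semantics of pack_args_kwargs()/unpack_args_kwargs() are inferred from their usage in this diff rather than verified against petals.utils.packaging:

import torch
from petals.utils.packaging import pack_args_kwargs, unpack_args_kwargs

hidden_states = torch.randn(1, 4, 16)
prompts = torch.zeros(2, 1, 4, 16)
hypo_ids = torch.arange(1)

# Flatten the positional tensors into a list plus a structure descriptor that can be
# shipped to the server inside the request metadata...
flat_tensors, args_structure = pack_args_kwargs(hidden_states, prompts, hypo_ids)

# ...and restored to the original call layout on the receiving side.
args, kwargs = unpack_args_kwargs(flat_tensors, args_structure)
restored_hidden_states, restored_prompts, restored_hypo_ids = args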

@ -6,8 +6,6 @@ import pydantic
from hivemind import PeerID
from hivemind.moe.expert_uid import ExpertUID
from petals.server.memory_cache import Handle
ModuleUID = str
UID_DELIMITER = "." # delimits parts of one module uid, e.g. "bloom.transformer.h.4.self_attention"
CHAIN_DELIMITER = " " # delimits multiple uids in a sequence, e.g. "bloom.layer3 bloom.layer4"
@ -78,6 +76,8 @@ class RemoteSpanInfo:
RPCInfo = Dict[str, Any]
Handle = int
@dataclasses.dataclass(frozen=True)
class InferenceMetadata:

@ -1,124 +1,9 @@
"""
Utilities for declaring and retrieving active model layers using a shared DHT.
"""
from __future__ import annotations
import warnings
import math
from functools import partial
from typing import Dict, List, Optional, Sequence, Union
warnings.warn(
"petals.dht_utils has been moved to petals.utils.dht. This alias will be removed in Petals 2.2.0+",
DeprecationWarning,
stacklevel=2,
)
from hivemind.dht import DHT, DHTNode, DHTValue
from hivemind.p2p import PeerID
from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger
from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo, ServerInfo
logger = get_logger(__name__)
def declare_active_modules(
dht: DHT,
uids: Sequence[ModuleUID],
server_info: ServerInfo,
expiration_time: DHTExpiration,
wait: bool = True,
) -> Union[Dict[ModuleUID, bool], MPFuture[Dict[ModuleUID, bool]]]:
"""
Declare that your node serves the specified modules; update timestamps if declared previously
:param uids: a list of module ids to declare
:param wait: if True, awaits for declaration to finish, otherwise runs in background
:param throughput: specify your performance in terms of compute throughput
:param expiration_time: declared modules will be visible for this many seconds
:returns: if wait, returns store status for every key (True = store succeeded, False = store rejected)
"""
if isinstance(uids, str):
uids = [uids]
if not isinstance(uids, list):
uids = list(uids)
for uid in uids:
assert isinstance(uid, ModuleUID) and UID_DELIMITER in uid and CHAIN_DELIMITER not in uid
return dht.run_coroutine(
partial(_declare_active_modules, uids=uids, server_info=server_info, expiration_time=expiration_time),
return_future=not wait,
)
async def _declare_active_modules(
dht: DHT,
node: DHTNode,
uids: List[ModuleUID],
server_info: ServerInfo,
expiration_time: DHTExpiration,
) -> Dict[ModuleUID, bool]:
num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
return await node.store_many(
keys=uids,
subkeys=[dht.peer_id.to_base58()] * len(uids),
values=[server_info.to_tuple()] * len(uids),
expiration_time=expiration_time,
num_workers=num_workers,
)
def get_remote_module_infos(
dht: DHT,
uids: Sequence[ModuleUID],
expiration_time: Optional[DHTExpiration] = None,
active_adapter: Optional[str] = None,
*,
latest: bool = False,
return_future: bool = False,
) -> Union[List[Optional[RemoteModuleInfo]], MPFuture]:
return dht.run_coroutine(
partial(
_get_remote_module_infos,
uids=uids,
active_adapter=active_adapter,
expiration_time=expiration_time,
latest=latest,
),
return_future=return_future,
)
async def _get_remote_module_infos(
dht: DHT,
node: DHTNode,
uids: List[ModuleUID],
active_adapter: Optional[str],
expiration_time: Optional[DHTExpiration],
latest: bool,
) -> List[Optional[RemoteModuleInfo]]:
if latest:
assert expiration_time is None, "You should define either `expiration_time` or `latest`, not both"
expiration_time = math.inf
elif expiration_time is None:
expiration_time = get_dht_time()
num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
found: Dict[ModuleUID, DHTValue] = await node.get_many(uids, expiration_time, num_workers=num_workers)
modules: List[Optional[RemoteModuleInfo]] = [None] * len(uids)
for i, uid in enumerate(uids):
metadata = found[uid]
if metadata is None or not isinstance(metadata.value, dict):
if metadata is not None:
logger.warning(f"Incorrect metadata for {uid}: {metadata}")
continue
servers = {}
for peer_id, server_info in metadata.value.items():
try:
peer_id = PeerID.from_base58(peer_id)
server_info = ServerInfo.from_tuple(server_info.value)
if active_adapter and active_adapter not in server_info.adapters:
logger.debug(f"Skipped server {peer_id} since it does not have adapter {active_adapter}")
continue
servers[peer_id] = server_info
except (TypeError, ValueError) as e:
logger.warning(f"Incorrect peer entry for uid={uid}, peer_id={peer_id}: {e}")
if servers:
modules[i] = RemoteModuleInfo(uid, servers)
return modules
from petals.utils.dht import *

@ -5,15 +5,15 @@ from hivemind import get_logger
from transformers.models.bloom import BloomConfig
from transformers.models.bloom.modeling_bloom import BloomAttention
from petals.client.config import ClientConfig
from petals.client.lm_head import LMHeadConfig
from petals.client.ptune import PTuneConfig
from petals.client.routing.sequence_manager import SequenceManagerConfig
from petals.models.bloom.block import WrappedBloomBlock
logger = get_logger(__name__)
class DistributedBloomConfig(BloomConfig, SequenceManagerConfig, PTuneConfig, LMHeadConfig):
class DistributedBloomConfig(BloomConfig, ClientConfig, PTuneConfig, LMHeadConfig):
block_class = WrappedBloomBlock
attn_class = BloomAttention
block_prefix = "h"

@ -10,7 +10,7 @@ from transformers.models.bloom import BloomForCausalLM, BloomForSequenceClassifi
from petals.client.from_pretrained import FromPretrainedMixin
from petals.client.lm_head import LMHead
from petals.client.ptune import PTuneMixin
from petals.client.remote_generation import RemoteGenerationMixin
from petals.client.remote_generation import RemoteGenerationMixin, RemotePastKeyValues
from petals.client.remote_sequential import RemoteSequential
from petals.models.bloom.config import DistributedBloomConfig
@ -39,16 +39,15 @@ class DistributedBloomModel(FromPretrainedMixin, PTuneMixin, BloomModel):
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
past_key_values: Optional[RemotePastKeyValues] = None,
attention_mask: Optional[torch.Tensor] = None,
**kwargs,
head_mask: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
assert attention_mask is None, f"{self.__class__.__name__} does not support attention masks right now"
for k, v in kwargs.items():
if not (v is None or v is False):
logger.debug(f"Extra keyword arguments are not yet supported (got {k} = {v})")
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
@ -59,21 +58,34 @@ class DistributedBloomModel(FromPretrainedMixin, PTuneMixin, BloomModel):
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
# The causal mask will be added on the server-side
assert (
attention_mask is None or (attention_mask == 1).all()
), f"Custom attention masks are not supported, {attention_mask=}"
assert head_mask is None, f"Custom head masks are not supported, {head_mask=}"
assert use_cache is None or use_cache, f"{use_cache=} is not supported"
assert not output_attentions, f"{output_attentions=} is not supported"
assert not output_hidden_states, f"{output_hidden_states=} is not supported"
assert return_dict is None or return_dict, f"{return_dict=} is not supported"
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
if self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.h.position == 0:
batch_size = inputs_embeds.shape[0]
prompts, intermediate_prompts = self.get_prompt(batch_size)
inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
else:
prompts = intermediate_prompts = None
hidden_states = self.word_embeddings_layernorm(inputs_embeds)
output_shape = input_shape + (hidden_states.size(-1),)
if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
hidden_states = self.h(hidden_states, prompts=intermediate_prompts)
else:
hidden_states = self.h(hidden_states)
hidden_states = self.h(
hidden_states,
prompts=intermediate_prompts,
hypo_ids=past_key_values.hypo_ids if past_key_values is not None else None,
)
# Remove prefix
if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
@ -84,7 +96,7 @@ class DistributedBloomModel(FromPretrainedMixin, PTuneMixin, BloomModel):
hidden_states = hidden_states.view(output_shape)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=None,
past_key_values=RemotePastKeyValues(),
hidden_states=None,
attentions=None,
)

@ -5,15 +5,15 @@ from hivemind import get_logger
from transformers.models.llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaAttention
from petals.client.config import ClientConfig
from petals.client.lm_head import LMHeadConfig
from petals.client.ptune import PTuneConfig
from petals.client.routing.sequence_manager import SequenceManagerConfig
from petals.models.llama.block import WrappedLlamaBlock
logger = get_logger(__name__)
class DistributedLlamaConfig(LlamaConfig, SequenceManagerConfig, PTuneConfig, LMHeadConfig):
class DistributedLlamaConfig(LlamaConfig, ClientConfig, PTuneConfig, LMHeadConfig):
block_class = WrappedLlamaBlock
attn_class = LlamaAttention
block_prefix = "model.layers"

@ -10,7 +10,7 @@ from transformers.models.llama import LlamaForCausalLM, LlamaForSequenceClassifi
from petals.client.from_pretrained import FromPretrainedMixin
from petals.client.lm_head import LMHead
from petals.client.ptune import PTuneMixin
from petals.client.remote_generation import RemoteGenerationMixin
from petals.client.remote_generation import RemoteGenerationMixin, RemotePastKeyValues
from petals.client.remote_sequential import RemoteSequential
from petals.models.llama.config import DistributedLlamaConfig
@ -39,16 +39,15 @@ class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel):
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
**kwargs,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[RemotePastKeyValues] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> BaseModelOutputWithPast:
assert attention_mask is None, f"{self.__class__.__name__} does not support attention masks right now"
for k, v in kwargs.items():
if not (v is None or v is False):
logger.debug(f"Extra keyword arguments are not yet supported (got {k} = {v})")
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
@ -59,21 +58,36 @@ class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel):
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
# The causal mask will be added on the server-side
assert (
attention_mask is None or (attention_mask == 1).all()
), f"Custom attention masks are not supported, {attention_mask=}"
assert (
position_ids is None or (position_ids[:, 1:] - position_ids[:, :-1] == 1).all()
), f"Non-consecutive position_ids are not supported, {position_ids=}"
assert use_cache is None or use_cache, f"{use_cache=} is not supported"
assert not output_attentions, f"{output_attentions=} is not supported"
assert not output_hidden_states, f"{output_hidden_states=} is not supported"
assert return_dict is None or return_dict, f"{return_dict=} is not supported"
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
if self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.layers.position == 0:
batch_size = inputs_embeds.shape[0]
prompts, intermediate_prompts = self.get_prompt(batch_size)
inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
else:
prompts = intermediate_prompts = None
hidden_states = inputs_embeds
output_shape = input_shape + (hidden_states.size(-1),)
if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
hidden_states = self.layers(hidden_states, prompts=intermediate_prompts)
else:
hidden_states = self.layers(hidden_states)
hidden_states = self.layers(
hidden_states,
prompts=intermediate_prompts,
hypo_ids=past_key_values.hypo_ids if past_key_values is not None else None,
)
# Remove prefix
if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
@ -84,7 +98,7 @@ class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel):
hidden_states = hidden_states.view(output_shape)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=None,
past_key_values=RemotePastKeyValues(),
hidden_states=None,
attentions=None,
)
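As a usage note for the rewritten forward above: the extra keyword arguments exist for compatibility with the Hugging Face signature, but the asserts only admit their default (or equivalent) values. A hedged sketch of a call that satisfies those constraints, assuming dist_model is a DistributedLlamaModel instance connected to a swarm:

import torch

input_ids = torch.tensor([[1, 2, 3, 4]])  # placeholder token ids
# An all-ones attention mask (or None) passes the assert; custom masks do not
outputs = dist_model(input_ids, attention_mask=torch.ones_like(input_ids))
hidden = outputs.last_hidden_state

# These would trip the asserts above:
# dist_model(input_ids, output_attentions=True)
# dist_model(input_ids, attention_mask=torch.zeros_like(input_ids))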

@ -3,21 +3,22 @@ This module implements server-side computations on served blocks: forward, backw
"""
from __future__ import annotations
from typing import AsyncIterator, Optional, Sequence, Tuple, Union
from typing import Any, AsyncIterator, Dict, Optional, Sequence, Tuple, Union
import torch
from hivemind.compression.serialization import deserialize_torch_tensor, serialize_torch_tensor
from hivemind.moe.expert_uid import ExpertUID
from hivemind.proto import runtime_pb2
from hivemind.utils.logging import get_logger
from hivemind.utils.nested import nested_flatten
from petals.data_structures import InferenceMetadata
from petals.data_structures import Handle, InferenceMetadata
from petals.server.backend import TransformerBackend
from petals.server.memory_cache import Handle
from petals.server.task_pool import PrioritizedTaskPool
from petals.server.task_prioritizer import TaskPrioritizerBase
from petals.utils.convert_block import QuantType
from petals.utils.misc import DUMMY, is_dummy
from petals.utils.packaging import unpack_args_kwargs
# We prioritize short inference requests and make them use a *merged* inference pool,
# so they are processed without interruptions and extra overheads
@ -25,6 +26,8 @@ from petals.utils.misc import DUMMY, is_dummy
MAX_SHORT_INFERENCE_TOKENS = 128
MAX_NF4_SHORT_INFERENCE_TOKENS = 1
logger = get_logger(__name__)
async def run_rpc_forward(
*flat_tensors: torch.Tensor,
@ -32,6 +35,7 @@ async def run_rpc_forward(
active_adapter: str = "",
prioritizer: TaskPrioritizerBase,
points: int = 0,
args_structure: Any = None,
) -> torch.Tensor:
"""
Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream
@ -41,7 +45,11 @@ async def run_rpc_forward(
:param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass
:returns: hidden states after the last layer [batch_size, seq_length, hid_size]
"""
hidden_states, prompts = flat_tensors
if args_structure is not None:
# TODO: kwargs are currently unused; they may be used later for peft-like adaptation
flat_tensors, kwargs = unpack_args_kwargs(flat_tensors, args_structure)
hidden_states, prompts, *_ = flat_tensors
dtype = requested_backends[0].dtype
# parse input tensors and cast dtypes
hidden_states = hidden_states.to(dtype)
@ -79,8 +87,13 @@ async def run_rpc_backward(
active_adapter: str = "",
prioritizer: TaskPrioritizerBase,
points: int = 0,
args_structure: Any = None,
) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
inputs, grad_outputs, prompts = flat_tensors
if args_structure is not None:
# TODO: kwargs are currently unused; they may be used later for peft-like adaptation
flat_tensors, kwargs = unpack_args_kwargs(flat_tensors, args_structure)
inputs, grad_outputs, prompts, *_ = flat_tensors
# Cast inputs & grad outputs to backend dtype
inputs = inputs.to(requested_backends[0].dtype)
grad_outputs = grad_outputs.to(requested_backends[-1].dtype)
@ -139,6 +152,7 @@ async def iterate_rpc_inference(
prioritizer: TaskPrioritizerBase,
points: int,
quant_type: QuantType,
args_structure: Any = None,
) -> AsyncIterator[Tuple[Sequence[runtime_pb2.Tensor], bool]]:
assert len(cache_handles) == len(requested_backends)
@ -146,7 +160,12 @@ async def iterate_rpc_inference(
point_per_piece = points / max_length if max_length > 0 else 0.0
async for request, step_metadata in input_iterator:
hidden_states, prompts, hypo_ids = map(deserialize_torch_tensor, request.tensors)
flat_tensors = tuple(deserialize_torch_tensor(tensor) for tensor in request.tensors)
if args_structure is not None:
# TODO: kwargs are currently unused; they may be used later for peft-like adaptation
flat_tensors, kwargs = unpack_args_kwargs(flat_tensors, args_structure)
hidden_states, prompts, hypo_ids, *_ = flat_tensors
batch_size, length_increment, _ = hidden_states.shape
# Cast inputs to backend dtype
@ -177,7 +196,7 @@ async def iterate_rpc_inference(
hypo_ids,
points=point_per_piece,
requested_uids=requested_uids,
type="short_inference" if can_merge_pools else "inference",
type="inference",
)
# A client may pass a tensor with 0 tokens. This is a special case that occurs, e.g.

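The args_structure parameter added above mirrors the packing scheme introduced in petals/utils/packaging.py later in this diff: the caller flattens all tensors out of (args, kwargs), sends them as plain tensors, and ships the remaining structure separately so the server can restore the original layout before reading hidden_states, prompts, and hypo_ids. A minimal sketch of that round trip (variable names are illustrative, not the exact client call sites):

import torch

from petals.utils.packaging import pack_args_kwargs, unpack_args_kwargs

hidden_states, prompts, hypo_ids = torch.randn(1, 4, 16), torch.empty(0), torch.arange(1)
flat_tensors, args_structure = pack_args_kwargs(hidden_states, prompts, hypo_ids)

# ... flat_tensors are serialized into the request; args_structure rides along in its metadata ...

restored, kwargs = unpack_args_kwargs(flat_tensors, args_structure)
hidden_states, prompts, hypo_ids, *_ = restored  # same unpacking pattern as the server code above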
@ -8,14 +8,17 @@ If necessary, one can rewrite this to implement a different behavior, such as:
"""
import json
import time
from contextlib import suppress
from typing import Dict, Optional, Union
import safetensors
import torch
import torch.nn as nn
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device
from hivemind.utils.logging import get_logger
from huggingface_hub import get_hf_file_metadata, hf_hub_url
from huggingface_hub.utils import EntryNotFoundError
from transformers import PretrainedConfig
from transformers.utils import get_file_from_repo
@ -61,7 +64,7 @@ def load_pretrained_block(
)
# dummy load, check that keys match
report = block.load_state_dict(state_dict, strict=True)
report = block.load_state_dict(state_dict, strict=False)
assert not report.missing_keys, f"Some block weights are missing: {report.missing_keys}"
for param_name, _ in block.named_parameters():
@ -90,11 +93,14 @@ def _load_state_dict_from_repo(
if always_needs_auth(model_name) and token is None:
token = True
index_file = get_file_from_repo(
model_name, filename="pytorch_model.bin.index.json", use_auth_token=token, cache_dir=cache_dir
)
if index_file is not None: # Sharded model
with open(index_file) as f:
index_file = _find_index_file(model_name, revision=revision, token=token, cache_dir=cache_dir)
if index_file.endswith(".index.json"): # Sharded model
path = get_file_from_repo(model_name, filename=index_file, use_auth_token=token, cache_dir=cache_dir)
if path is None:
# _find_index_file() reported that the file exists, but we could not fetch it (e.g., it has just disappeared)
raise ValueError(f"Failed to get file {index_file}")
with open(path) as f:
index = json.load(f)
filenames = {
filename for param_name, filename in index["weight_map"].items() if param_name.startswith(block_prefix)
@ -102,14 +108,15 @@ def _load_state_dict_from_repo(
if not filenames:
raise RuntimeError(f"Block {block_prefix}* not found in the index: {index['weight_map']}")
else: # Non-sharded model
filenames = {"pytorch_model.bin"}
filenames = {index_file}
logger.debug(f"Loading {block_prefix}* from {filenames}")
state_dict = {}
for filename in filenames:
shard_state_dict = _load_state_dict_from_file(
shard_state_dict = _load_state_dict_from_repo_file(
model_name,
filename,
block_prefix=block_prefix,
revision=revision,
token=token,
cache_dir=cache_dir,
@ -124,10 +131,42 @@ def _load_state_dict_from_repo(
return state_dict
def _load_state_dict_from_file(
INDEX_FILES = ["model.safetensors.index.json", "model.safetensors", "pytorch_model.bin.index.json", "pytorch_model.bin"]
def _find_index_file(
model_name: str, *, revision: Optional[str] = None, token: Optional[Union[str, bool]] = None, cache_dir: str
) -> str:
# If we have cached weights (e.g., Pickle from older Petals versions), reuse them
for filename in INDEX_FILES:
path = get_file_from_repo(
model_name,
filename,
revision=revision,
use_auth_token=token,
cache_dir=cache_dir,
local_files_only=True,
)
if path is not None:
return filename
# If we don't, prefer Safetensors when possible
# (we don't download files here since we can't account for max_disk_space in case of large files)
for filename in INDEX_FILES:
with suppress(EntryNotFoundError):
get_hf_file_metadata(hf_hub_url(model_name, filename, revision=revision), token=token)
return filename
raise ValueError(
f"Repo {model_name} does not contain weights in a supported format: files {INDEX_FILES} do not exist"
)
def _load_state_dict_from_repo_file(
model_name: str,
filename: str,
*,
block_prefix: Optional[str] = None,
revision: Optional[str] = None,
token: Optional[Union[str, bool]] = None,
cache_dir: str,
@ -146,7 +185,7 @@ def _load_state_dict_from_file(
local_files_only=True,
)
if path is not None:
return torch.load(path, map_location="cpu")
return _load_state_dict_from_local_file(path, block_prefix=block_prefix)
except Exception:
logger.warning(f"Cache for file {filename} is corrupted, it will be downloaded again", exc_info=True)
@ -171,7 +210,18 @@ def _load_state_dict_from_file(
)
if path is None:
raise RuntimeError(f"File {filename} does not exist in repo {model_name}")
return torch.load(path, map_location="cpu")
return _load_state_dict_from_local_file(path, block_prefix=block_prefix)
except Exception as e:
logger.warning(f"Failed to load file {filename} from HF Hub (retry in {delay:.0f} sec)", exc_info=True)
time.sleep(delay)
def _load_state_dict_from_local_file(path: str, *, block_prefix: Optional[str] = None) -> StateDict:
if path.endswith(".bin"):
return torch.load(path, map_location="cpu")
if path.endswith(".safetensors"):
with safetensors.safe_open(path, framework="pt", device="cpu") as f:
return {key: f.get_tensor(key) for key in f.keys() if block_prefix is None or key.startswith(block_prefix)}
raise ValueError(f"Unknown weight format: {path}")

@ -29,10 +29,9 @@ from hivemind.utils.logging import get_logger
from hivemind.utils.streaming import split_for_streaming
import petals
from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID
from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, Handle, ModuleUID
from petals.server.backend import TransformerBackend
from petals.server.block_functions import iterate_rpc_inference, run_rpc_backward, run_rpc_forward
from petals.server.memory_cache import Handle
from petals.server.task_prioritizer import DummyTaskPrioritizer, TaskPrioritizerBase
from petals.utils.convert_block import QuantType
@ -151,6 +150,7 @@ class TransformerConnectionHandler(ConnectionHandler):
max_length = metadata.get("max_length")
points = metadata.get("points", 0)
session_id = metadata.get("session_id")
args_structure = metadata.get("args_structure")
if not requested_uids:
raise ValueError("User must specify at least one block for inference, but got none")
assert isinstance(
@ -180,6 +180,7 @@ class TransformerConnectionHandler(ConnectionHandler):
prioritizer=self._prioritizer,
points=points,
quant_type=self.quant_type,
args_structure=args_structure,
):
if can_push:
task = asyncio.create_task(self._push_outputs(request, output_tensors[0], metadata))
@ -356,6 +357,7 @@ class TransformerConnectionHandler(ConnectionHandler):
metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
active_adapter = self._get_active_adapter(metadata)
points = metadata.get("points", 0)
args_structure = metadata.get("args_structure")
assert isinstance(
points, (float, int)
), f"rpc_forward should have number of points as number or None, got {points}"
@ -366,6 +368,7 @@ class TransformerConnectionHandler(ConnectionHandler):
prioritizer=self._prioritizer,
active_adapter=active_adapter,
points=points,
args_structure=args_structure,
)
return runtime_pb2.ExpertResponse(
tensors=self._serialize_outputs(hidden_states, requested_backends, metadata)
@ -383,6 +386,7 @@ class TransformerConnectionHandler(ConnectionHandler):
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
active_adapter = self._get_active_adapter(metadata)
points = metadata.get("points", 0)
args_structure = metadata.get("args_structure")
assert isinstance(
points, (float, int)
), f"rpc_forward_stream should have number of points as number or None, got {points}"
@ -393,6 +397,7 @@ class TransformerConnectionHandler(ConnectionHandler):
prioritizer=self._prioritizer,
active_adapter=active_adapter,
points=points,
args_structure=args_structure,
)
# Split the serialized_output for streaming and respond to client
@ -434,6 +439,7 @@ class TransformerConnectionHandler(ConnectionHandler):
metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
active_adapter = self._get_active_adapter(metadata)
points = metadata.get("points", 0)
args_structure = metadata.get("args_structure")
assert isinstance(
points, (float, int)
), f"rpc_backward should have number of points as number or None, got {points}"
@ -444,6 +450,7 @@ class TransformerConnectionHandler(ConnectionHandler):
prioritizer=self._prioritizer,
active_adapter=active_adapter,
points=points,
args_structure=args_structure,
)
return runtime_pb2.ExpertResponse(tensors=self._serialize_grads(grads, requested_backends, metadata))
@ -459,6 +466,7 @@ class TransformerConnectionHandler(ConnectionHandler):
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
active_adapter = self._get_active_adapter(metadata)
points = metadata.get("points", 0)
args_structure = metadata.get("args_structure")
assert isinstance(
points, (float, int)
), f"rpc_backward_stream should have number of points as number or None, got {points}"
@ -469,6 +477,7 @@ class TransformerConnectionHandler(ConnectionHandler):
prioritizer=self._prioritizer,
active_adapter=active_adapter,
points=points,
args_structure=args_structure,
)
# Split the serialized_grad_inputs for streaming and respond
for tensor in self._serialize_grads(grads, requested_backends, metadata):

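On the wire, args_structure is simply one more entry of the MessagePack-encoded metadata that these handlers already parse with MSGPackSerializer.loads. A hedged sketch of the client-side counterpart (key names follow the metadata.get(...) calls above; the exact client code is outside this diff):

from hivemind import MSGPackSerializer

# Hypothetical assembly: args_structure comes from pack_args_kwargs(...)
metadata = MSGPackSerializer.dumps(dict(points=0, args_structure=args_structure))
# The resulting bytes go into the request's metadata field decoded by the handlers above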
@ -16,12 +16,11 @@ import hivemind
import torch
from hivemind.utils import TensorDescriptor, get_logger
from petals.data_structures import Handle
from petals.utils.asyncio import shield_and_wait
logger = get_logger(__name__)
Handle = int
class MemoryCache:
"""A shared cache for storing tensors that persist across calls. Main use case: storing past attention KVs"""

@ -20,7 +20,6 @@ from transformers import PretrainedConfig
import petals
from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS
from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerInfo, ServerState
from petals.dht_utils import declare_active_modules, get_remote_module_infos
from petals.server import block_selection
from petals.server.backend import TransformerBackend, merge_inference_pools_inplace
from petals.server.block_utils import get_block_size, resolve_block_dtype
@ -31,6 +30,7 @@ from petals.server.reachability import ReachabilityProtocol, check_direct_reacha
from petals.server.throughput import get_dtype_name, get_server_throughput
from petals.utils.auto_config import AutoDistributedConfig
from petals.utils.convert_block import QuantType, check_device_balance, convert_block
from petals.utils.dht import declare_active_modules, get_remote_module_infos
from petals.utils.ping import PingAggregator
from petals.utils.random import sample_up_to
from petals.utils.version import get_compatible_model_repo

@ -14,9 +14,7 @@ class TaskPrioritizerBase(ABC):
class DummyTaskPrioritizer(TaskPrioritizerBase):
def prioritize(self, *input: torch.Tensor, points: float = 0.0, **kwargs) -> float:
# Inference steps (especially short ones) go first since they are more latency-sensitive
if kwargs.get("type") == "short_inference":
return 1.0
# Inference steps go first since they are more latency-sensitive
if kwargs.get("type") == "inference":
return 2.0
return 3.0 # Forward, backward
return 1.0
return 2.0 # Forward, backward
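The returned numbers are scheduling priorities; as the comment implies, smaller values are served earlier. A hedged sketch of a custom prioritizer following the same convention (the points-based tie-breaking is an illustrative assumption, not something this PR implements):

import torch

from petals.server.task_prioritizer import TaskPrioritizerBase


class PointsAwarePrioritizer(TaskPrioritizerBase):
    """Serve inference first, then order forward/backward by the points a request carries."""

    def prioritize(self, *input: torch.Tensor, points: float = 0.0, **kwargs) -> float:
        if kwargs.get("type") == "inference":
            return 1.0
        return 2.0 + 1.0 / (1.0 + points)  # more points -> smaller value -> served earlier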

@ -4,3 +4,4 @@ from petals.utils.auto_config import (
AutoDistributedModelForCausalLM,
AutoDistributedModelForSequenceClassification,
)
from petals.utils.dht import declare_active_modules, get_remote_module_infos

@ -0,0 +1,124 @@
"""
Utilities for declaring and retrieving active model layers using a shared DHT.
"""
from __future__ import annotations
import math
from functools import partial
from typing import Dict, List, Optional, Sequence, Union
from hivemind.dht import DHT, DHTNode, DHTValue
from hivemind.p2p import PeerID
from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger
from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo, ServerInfo
logger = get_logger(__name__)
def declare_active_modules(
dht: DHT,
uids: Sequence[ModuleUID],
server_info: ServerInfo,
expiration_time: DHTExpiration,
wait: bool = True,
) -> Union[Dict[ModuleUID, bool], MPFuture[Dict[ModuleUID, bool]]]:
"""
Declare that your node serves the specified modules; update timestamps if declared previously
:param uids: a list of module ids to declare
:param wait: if True, awaits for declaration to finish, otherwise runs in background
:param server_info: your server's metadata to announce for each module, including its measured throughput
:param expiration_time: declared modules will be visible for this many seconds
:returns: if wait, returns store status for every key (True = store succeeded, False = store rejected)
"""
if isinstance(uids, str):
uids = [uids]
if not isinstance(uids, list):
uids = list(uids)
for uid in uids:
assert isinstance(uid, ModuleUID) and UID_DELIMITER in uid and CHAIN_DELIMITER not in uid
return dht.run_coroutine(
partial(_declare_active_modules, uids=uids, server_info=server_info, expiration_time=expiration_time),
return_future=not wait,
)
async def _declare_active_modules(
dht: DHT,
node: DHTNode,
uids: List[ModuleUID],
server_info: ServerInfo,
expiration_time: DHTExpiration,
) -> Dict[ModuleUID, bool]:
num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
return await node.store_many(
keys=uids,
subkeys=[dht.peer_id.to_base58()] * len(uids),
values=[server_info.to_tuple()] * len(uids),
expiration_time=expiration_time,
num_workers=num_workers,
)
def get_remote_module_infos(
dht: DHT,
uids: Sequence[ModuleUID],
expiration_time: Optional[DHTExpiration] = None,
active_adapter: Optional[str] = None,
*,
latest: bool = False,
return_future: bool = False,
) -> Union[List[Optional[RemoteModuleInfo]], MPFuture]:
return dht.run_coroutine(
partial(
_get_remote_module_infos,
uids=uids,
active_adapter=active_adapter,
expiration_time=expiration_time,
latest=latest,
),
return_future=return_future,
)
async def _get_remote_module_infos(
dht: DHT,
node: DHTNode,
uids: List[ModuleUID],
active_adapter: Optional[str],
expiration_time: Optional[DHTExpiration],
latest: bool,
) -> List[Optional[RemoteModuleInfo]]:
if latest:
assert expiration_time is None, "You should define either `expiration_time` or `latest`, not both"
expiration_time = math.inf
elif expiration_time is None:
expiration_time = get_dht_time()
num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
found: Dict[ModuleUID, DHTValue] = await node.get_many(uids, expiration_time, num_workers=num_workers)
modules: List[Optional[RemoteModuleInfo]] = [None] * len(uids)
for i, uid in enumerate(uids):
metadata = found[uid]
if metadata is None or not isinstance(metadata.value, dict):
if metadata is not None:
logger.warning(f"Incorrect metadata for {uid}: {metadata}")
continue
servers = {}
for peer_id, server_info in metadata.value.items():
try:
peer_id = PeerID.from_base58(peer_id)
server_info = ServerInfo.from_tuple(server_info.value)
if active_adapter and active_adapter not in server_info.adapters:
logger.debug(f"Skipped server {peer_id} since it does not have adapter {active_adapter}")
continue
servers[peer_id] = server_info
except (TypeError, ValueError) as e:
logger.warning(f"Incorrect peer entry for uid={uid}, peer_id={peer_id}: {e}")
if servers:
modules[i] = RemoteModuleInfo(uid, servers)
return modules
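A usage sketch for the retrieval half of this new module, assuming a client-mode hivemind DHT and block UIDs of the form f"{dht_prefix}{UID_DELIMITER}{index}" (the same shape used in the tests below); the peer address and prefix are placeholders:

import hivemind

from petals.data_structures import UID_DELIMITER
from petals.utils.dht import get_remote_module_infos

dht = hivemind.DHT(
    initial_peers=["/ip4/10.0.0.1/tcp/31337/p2p/PEER_ID"],  # placeholder multiaddr
    client_mode=True,
    start=True,
)
block_uids = [f"my-model-prefix{UID_DELIMITER}{i}" for i in range(4)]  # placeholder prefix

infos = get_remote_module_infos(dht, block_uids, latest=True)
for uid, info in zip(block_uids, infos):
    num_servers = len(info.servers) if info is not None else 0
    print(f"{uid} is announced by {num_servers} server(s)")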

@ -1,128 +0,0 @@
from abc import ABC, abstractmethod
from typing import Tuple
import torch
TokenIds = torch.Tensor
HypoIds = torch.Tensor
class DecodingAlgorithm(ABC):
"""
An abstract base class for decoding algorithms. Given logits, an algorithm selects new tokens
and reports which hypotheses those tokens extend.
"""
@abstractmethod
def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
"""
:param logits: A tensor of shape (batch_size, seq_length, vocab_size)
:return: A tuple of selected token ids and corresponding hypotheses.
The shape of the token ids is (batch_size, seq_length), and the shape of the hypotheses is (batch_size)
"""
pass
class GreedyAlgorithm(DecodingAlgorithm):
"""
The simplest algorithm for decoding. It selects the most probable token.
"""
def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
"""
Returns the most probable token. The second returned object is always a range of integers
from 0 to batch_size - 1.
"""
return logits.max(-1)[1].unsqueeze(1), torch.arange(logits.size(0))
class SamplingAlgorithm(DecodingAlgorithm):
def __init__(self, temperature: float = 1.0):
self.temperature = temperature
def sample(self, logits: torch.Tensor, indices_to_remove: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
"""
:param logits: A tensor of shape (batch_size * num_hypos, vocab_size)
:param indices_to_remove: A bool tensor of shape (batch_size * num_hypos, vocab_size)
:return: A tuple of selected token ids and corresponding hypotheses.
The shape of the token ids is (batch_size, seq_length), and the shape of the hypotheses is (batch_size).
"""
logits[indices_to_remove] = -float("Inf")
probs = torch.softmax(logits / self.temperature, -1)
return torch.multinomial(probs, num_samples=1), torch.arange(logits.size(0))
def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
indices_to_remove = torch.full_like(logits, False, dtype=torch.bool)
return self.sample(logits, indices_to_remove)
class TopKAlgorithm(SamplingAlgorithm):
def __init__(self, top_k: int, temperature: float = 1.0) -> None:
super().__init__(temperature=temperature)
self.top_k = top_k
def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
indices_to_remove = logits < torch.topk(logits, self.top_k, dim=-1)[0][..., -1, None]
return self.sample(logits, indices_to_remove)
class NucleusAlgorithm(SamplingAlgorithm):
def __init__(self, top_p: float, temperature: float = 1.0) -> None:
super().__init__(temperature=temperature)
self.top_p = top_p
def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
sorted_logits, sorted_indices = torch.sort(logits, descending=False, dim=-1)
probs = torch.softmax(sorted_logits / self.temperature, -1)
cumulative_probs = torch.cumsum(probs, dim=-1)
sorted_indices_to_remove = cumulative_probs <= (1 - self.top_p)
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
return self.sample(logits, indices_to_remove)
class BeamSearchAlgorithm(DecodingAlgorithm):
def __init__(self, num_beams: int, batch_size: int) -> None:
self.num_beams = num_beams
self.batch_size = batch_size
self._batch_beams = [list() for _ in range(batch_size)]
def __call__(self, logits: torch.Tensor):
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
probs = torch.log_softmax(sorted_logits, -1)
if len(self._batch_beams[0]) > 0:
for batch_idx in range(self.batch_size):
new_beams = []
cur_beams = self._batch_beams[batch_idx]
for beam_idx in range(len(cur_beams)):
probs_idx = batch_idx + beam_idx * self.batch_size
new_beam = cur_beams[beam_idx]
for hypo_idx in range(self.num_beams):
new_beams.append(
(new_beam[0] + probs[probs_idx, hypo_idx].item(), beam_idx * self.num_beams + hypo_idx)
)
self._batch_beams[batch_idx] = sorted(new_beams, reverse=True)[: self.num_beams]
else:
for batch_idx in range(self.batch_size):
for beam_idx in range(self.num_beams):
self._batch_beams[batch_idx].append((probs[batch_idx, beam_idx].item(), beam_idx))
return_hypos = []
return_tokens = []
for batch_idx in range(self.batch_size):
cur_beam = self._batch_beams[batch_idx]
return_hypos.append(list())
return_tokens.append(list())
for beam in cur_beam:
beam_idx = beam[1] // self.num_beams
hypo_idx = batch_idx + beam_idx * self.batch_size
token_idx = beam[1] % self.num_beams
return_hypos[-1].append(hypo_idx)
return_tokens[-1].append([sorted_indices[hypo_idx, token_idx].item()])
return_hypos = [hypo_idx for hypo_indexes in zip(*return_hypos) for hypo_idx in hypo_indexes]
return_tokens = [token_idx for token_indexes in zip(*return_tokens) for token_idx in token_indexes]
return torch.tensor(return_tokens), torch.tensor(return_hypos)

@ -1,51 +0,0 @@
from abc import ABC
import torch
class ABCBloomConstraint(ABC):
"""
Base class of all kind of decoding constraints. It can be used to implement a new constraint.
"""
def __init__(self) -> None:
pass
def __call__(self, tokens_id: torch.Tensor, logits: torch.Tensor, hypo_ids: torch.Tensor) -> torch.Tensor:
"""
This method is called by the decoding algorithm to apply the constraint. It changes and returns new logits.
:param tokens_id: The token id of the last chosen token.
:param logits: The logits from the Bloom model.
:param hypo_ids: The hypothesis ids of the last tokens.
"""
pass
class EosConstraint(ABCBloomConstraint):
"""
This constraint repeats the EOS token if it was generated on the previous step.
Args:
prefix: The prefix of the sequence.
eos_token_id: The id of the end of sentence token.
pad_token_id: The id of the padding token.
min_logits: The minimum logits that can be generated. Default: -1e8.
"""
def __init__(self, prefix: torch.Tensor, eos_token_id: int, pad_token_id: int, min_logits: float = -1e8) -> None:
self.eos_token_id = eos_token_id
self.min_logits = min_logits
self.past_tokens = None
self.wait_until_starting = (prefix == pad_token_id).sum(1).unsqueeze(1)
def __call__(self, tokens_id: torch.Tensor, logits: torch.Tensor, hypo_ids: torch.Tensor) -> torch.Tensor:
if self.past_tokens is not None:
mask = (self.wait_until_starting < 0) & (self.past_tokens == self.eos_token_id)
logits += self.min_logits * mask
logits[mask[:, 0], self.eos_token_id] = 0
if tokens_id is not None:
self.past_tokens = tokens_id
self.wait_until_starting -= 1
return logits

@ -2,6 +2,16 @@ import torch
DUMMY = torch.empty(0) # dummy tensor that replaces empty prompt or adapter parameters
DUMMY_INT64 = torch.empty(0, dtype=torch.int64)
def is_dummy(tensor: torch.Tensor):
def is_dummy(tensor: torch.Tensor) -> bool:
return tensor.numel() == 0
def docstring_from(source):
def add_docstring(dest):
dest.__doc__ = source.__doc__
return dest
return add_docstring
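The new docstring_from helper copies a docstring from one callable to another, which is convenient when a distributed class re-implements a Hugging Face method. A small illustrative usage with hypothetical function names:

from petals.utils.misc import docstring_from


def reference_generate(prompt):
    """Generate a continuation for the given prompt."""


@docstring_from(reference_generate)
def distributed_generate(prompt):
    ...  # re-implementation that should advertise the same docstring


assert distributed_generate.__doc__ == reference_generate.__doc__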

@ -0,0 +1,49 @@
from typing import Any, Dict, List, Tuple
import torch
from hivemind import nested_flatten, nested_pack
# TODO: Move functions to hivemind
def _mark_masked_tensor(index: int) -> bytes:
return b"__T" + str(index).encode()
def _is_masked_tensor(item: Any) -> bool:
return isinstance(item, bytes) and item.startswith(b"__T")
def _get_tensor_index(item: bytes) -> int:
return int(item[3:])
def pack_args_kwargs(*args, **kwargs) -> Tuple[List[torch.Tensor], Any]:
"""
Find all tensors in the given args and kwargs and pack them into a single flat list.
:returns: a flat list of unique tensors and the (args, kwargs) structure with each tensor replaced by an index marker
"""
masked_flat_values, flat_tensors, tensor_to_index = [], [], {}
for value in nested_flatten((args, kwargs)):
if isinstance(value, torch.Tensor):
tensor_index = tensor_to_index.setdefault(value, len(flat_tensors))
if tensor_index == len(flat_tensors):
flat_tensors.append(value)
masked_flat_values.append(_mark_masked_tensor(tensor_index))
else:
masked_flat_values.append(value)
return flat_tensors, nested_pack(masked_flat_values, (args, kwargs))
def unpack_args_kwargs(flat_tensors: List[torch.Tensor], args_structure: Any):
"""
Restore the original (args, kwargs) packed by `pack_args_kwargs`.
:returns: a tuple of positional args and a dict of kwargs, with the tensors put back in place
"""
return nested_pack(
(
value if not _is_masked_tensor(value) else flat_tensors[_get_tensor_index(value)]
for value in nested_flatten(args_structure)
),
args_structure,
)
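A compact round trip for the two helpers above, complementing the unit test added later in this diff; all values are arbitrary:

import torch

from petals.utils.packaging import pack_args_kwargs, unpack_args_kwargs

x, y = torch.ones(2, 3), torch.arange(4)
flat_tensors, structure = pack_args_kwargs(x, scale=0.5, extra=(y, None))

# `structure` mirrors (args, kwargs) with every tensor replaced by a b"__T<index>" marker
restored_args, restored_kwargs = unpack_args_kwargs(flat_tensors, structure)
assert torch.equal(restored_args[0], x)
assert restored_kwargs["extra"][1] is None and torch.equal(restored_kwargs["extra"][0], y)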

@ -3,10 +3,13 @@ import sys
import pytest
import torch
from hivemind import nested_compare, nested_flatten
from petals import AutoDistributedConfig
from petals.server.throughput import measure_compute_rps
from petals.utils.convert_block import QuantType
from petals.utils.misc import DUMMY, is_dummy
from petals.utils.packaging import pack_args_kwargs, unpack_args_kwargs
from test_utils import MODEL_NAME
@ -44,3 +47,29 @@ def test_compute_throughput(inference: bool, n_tokens: int, tensor_parallel: boo
inference=inference,
)
assert isinstance(compute_rps, float) and compute_rps > 0
@pytest.mark.forked
def test_pack_inputs():
x = torch.ones(3)
y = torch.arange(5)
z = DUMMY
args = (x, z, None, (y, y), z)
kwargs = dict(foo=torch.zeros(1, 1), bar={"l": "i", "g": "h", "t": ("y", "e", "a", "r", torch.rand(1), x, y)})
flat_tensors, args_structure = pack_args_kwargs(*args, **kwargs)
assert len(flat_tensors) == 5
assert all(isinstance(t, torch.Tensor) for t in flat_tensors)
restored_args, restored_kwargs = unpack_args_kwargs(flat_tensors, args_structure)
assert len(restored_args) == len(args)
assert torch.all(restored_args[0] == x).item() and restored_args[2] is None
assert nested_compare((args, kwargs), (restored_args, restored_kwargs))
for original, restored in zip(nested_flatten((args, kwargs)), nested_flatten((restored_args, restored_kwargs))):
if isinstance(original, torch.Tensor):
assert torch.all(original == restored)
else:
assert original == restored

@ -3,7 +3,6 @@ import pytest
import torch
import transformers
from hivemind import get_logger
from transformers.generation import BeamSearchScorer, GenerationMixin as HfGenerationMixin
from petals import AutoDistributedModelForCausalLM
from test_utils import *
@ -17,18 +16,29 @@ def tokenizer():
return transformers.AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)
@pytest.fixture
def model():
return AutoDistributedModelForCausalLM.from_pretrained(
MODEL_NAME, initial_peers=INITIAL_PEERS, torch_dtype=torch.float32
)
@pytest.fixture
def ref_model():
return transformers.AutoModelForCausalLM.from_pretrained(
REF_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32
)
@pytest.mark.forked
@pytest.mark.parametrize("use_peft", (True, False) if ADAPTER_NAME else (False,))
@pytest.mark.parametrize("pass_empty_tensors", (True, False))
def test_full_model_exact_match(tokenizer, use_peft, pass_empty_tensors, atol_forward=1e-3, atol_inference=1e-3):
model = AutoDistributedModelForCausalLM.from_pretrained(
MODEL_NAME,
initial_peers=INITIAL_PEERS,
torch_dtype=torch.float32,
active_adapter=ADAPTER_NAME if use_peft else None,
)
config = model.config
assert len(model.transformer.h) == model.config.num_hidden_layers
def test_full_model_exact_match(tokenizer, model, ref_model, use_peft, pass_empty_tensors, atol=1e-3):
if use_peft:
model.config.active_adapter = ADAPTER_NAME
ref_model = peft.PeftModel.from_pretrained(ref_model, ADAPTER_NAME)
ref_model.train(False)
test_inputs = tokenizer("A quick brown fox was minding its own business", return_tensors="pt")["input_ids"]
@ -42,7 +52,7 @@ def test_full_model_exact_match(tokenizer, use_peft, pass_empty_tensors, atol_fo
recurrent_outputs = []
with model.transformer.h.inference_session(max_length=embs.shape[1]) as sess:
if pass_empty_tensors:
recurrent_outputs.append(sess.step(torch.empty(1, 0, config.hidden_size)))
recurrent_outputs.append(sess.step(torch.empty(1, 0, model.config.hidden_size)))
for t in range(embs.shape[1]):
if t == 4:
@ -53,52 +63,39 @@ def test_full_model_exact_match(tokenizer, use_peft, pass_empty_tensors, atol_fo
recurrent_outputs.append(sess.step(embs[:, t : t + 1, :]))
if t == 2 and pass_empty_tensors:
recurrent_outputs.append(sess.step(torch.empty(1, 0, config.hidden_size)))
recurrent_outputs.append(sess.step(torch.empty(1, 0, config.hidden_size)))
recurrent_outputs.append(sess.step(torch.empty(1, 0, model.config.hidden_size)))
recurrent_outputs.append(sess.step(torch.empty(1, 0, model.config.hidden_size)))
recurrent_outputs = torch.cat(recurrent_outputs, dim=1)
recurrent_outputs = model.transformer.ln_f(recurrent_outputs)
recurrent_outputs = model.lm_head(recurrent_outputs)
assert torch.allclose(recurrent_outputs, parallel_outputs, rtol=0, atol=atol_inference)
logger.info("Inference is consistent with forward")
del model, embs, recurrent_outputs
if REF_NAME:
ref_model = transformers.AutoModelForCausalLM.from_pretrained(
REF_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32
)
if use_peft:
ref_model = peft.PeftModel.from_pretrained(ref_model, ADAPTER_NAME)
ref_model.train(False)
if config.vocab_size < ref_model.config.vocab_size:
ref_model.resize_token_embeddings(config.vocab_size)
logger.warning(f"Resized the reference model embeddings, new total = {ref_model.config.vocab_size}")
dummy_mask = torch.ones_like(test_inputs, dtype=torch.bool)
# note: this creates a dummy mask to make the test compatible with older transformer versions
# prior to https://github.com/huggingface/transformers/pull/17837
ref_outputs = ref_model.forward(test_inputs, attention_mask=dummy_mask).logits.float()
assert torch.allclose(ref_outputs, parallel_outputs, rtol=0, atol=atol_forward)
logger.warning(f"Distributed forward is consistent with {type(ref_model)}.forward")
del ref_model, ref_outputs, dummy_mask
else:
logger.warning("Did not test exact match with local model: REF_NAME environment variable is not set")
assert False
assert torch.allclose(
recurrent_outputs, parallel_outputs, rtol=0, atol=atol
), "Inference differs from forward pass"
ref_outputs = ref_model.forward(test_inputs).logits.float()
assert torch.allclose(ref_outputs, parallel_outputs, rtol=0, atol=atol), "Outputs are not identical to HF"
def make_generate_calls(model, inputs, *, max_new_tokens, multiple_calls=False, **kwargs):
if not multiple_calls:
return model.generate(inputs, max_new_tokens=max_new_tokens, **kwargs)
with model.inference_session(max_length=inputs.shape[1] + max_new_tokens) as sess:
return torch.cat(
[
# Sessions provided both explicitly and implicitly should work
model.generate(inputs, max_new_tokens=1, **kwargs, session=sess),
model.generate(None, max_new_tokens=max_new_tokens - 2, **kwargs),
model.generate(None, max_new_tokens=1, **kwargs),
],
dim=1,
)
@pytest.mark.forked
def test_greedy_generation(tokenizer, max_new_tokens=4):
model = AutoDistributedModelForCausalLM.from_pretrained(
MODEL_NAME, initial_peers=INITIAL_PEERS, torch_dtype=torch.float32
)
inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
remote_outputs = model.generate(
inputs,
max_new_tokens=max_new_tokens,
)
hf_outputs = HfGenerationMixin.greedy_search(model, input_ids=inputs, max_length=inputs.size(1) + max_new_tokens)
assert torch.allclose(remote_outputs, hf_outputs), "Greedy search results are not identical to HF"
def test_greedy_generation(tokenizer, model, ref_model, max_new_tokens=4):
inputs_single = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
if tokenizer.pad_token_id is None:
tokenizer.pad_token_id = tokenizer.eos_token_id
@ -106,85 +103,49 @@ def test_greedy_generation(tokenizer, max_new_tokens=4):
"input_ids"
]
remote_outputs_batch = model.generate(
inputs_batch,
max_new_tokens=max_new_tokens,
)
hf_outputs_batch = HfGenerationMixin.greedy_search(
model, input_ids=inputs_batch, max_length=inputs_batch.size(1) + max_new_tokens
)
assert torch.allclose(
remote_outputs_batch, hf_outputs_batch
), "Greedy search results are not identical to HF in multibatch mode"
options = dict(max_new_tokens=max_new_tokens, do_sample=False)
for multiple_calls in [False, True]:
for inputs in [inputs_single, inputs_batch]:
outputs = make_generate_calls(model, inputs, multiple_calls=multiple_calls, **options)
ref_outputs = ref_model.generate(inputs, **options)
assert torch.allclose(
outputs, ref_outputs
), f"Greedy generation is not identical to HF with {multiple_calls=}, {inputs.shape=}"
@pytest.mark.forked
@pytest.mark.parametrize("sampling_options", [dict(), dict(temperature=100.0), dict(top_k=5), dict(top_p=0.9)])
@pytest.mark.skip("Sampling is currently not consistent with outputs from Transformers")
def test_sampling(tokenizer, sampling_options, max_new_tokens=4):
torch.manual_seed(0)
model = AutoDistributedModelForCausalLM.from_pretrained(
MODEL_NAME, initial_peers=INITIAL_PEERS, torch_dtype=torch.float32
)
logits_warper = HfGenerationMixin._get_logits_warper(model, num_beams=1, **sampling_options)
inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
with torch.random.fork_rng():
remote_outputs = model.generate(
inputs,
max_new_tokens=max_new_tokens,
do_sample=True,
**sampling_options,
)
with torch.random.fork_rng():
hf_outputs = HfGenerationMixin.sample(
model, input_ids=inputs, max_length=inputs.size(1) + max_new_tokens, logits_warper=logits_warper
)
assert torch.allclose(remote_outputs, hf_outputs), "Sampling results are not identical to HF"
def test_sampling(tokenizer, model, ref_model, max_new_tokens=10):
inputs_single = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
if tokenizer.pad_token_id is None:
tokenizer.pad_token_id = tokenizer.eos_token_id
inputs_batch = tokenizer(["A cat sat on a mat", "A dog sat on a mat"], return_tensors="pt", padding=True)[
"input_ids"
]
with torch.random.fork_rng():
remote_outputs_batch = model.generate(
inputs_batch,
max_new_tokens=max_new_tokens,
do_sample=True,
**sampling_options,
)
with torch.random.fork_rng():
hf_outputs_batch = HfGenerationMixin.sample(
model,
input_ids=inputs_batch,
max_length=inputs_batch.size(1) + max_new_tokens,
logits_warper=logits_warper,
)
assert torch.allclose(
remote_outputs_batch, hf_outputs_batch
), "Sampling results are not identical to HF in multibatch mode"
for options in [
dict(do_sample=True, temperature=0.5, top_k=5, top_p=0.9),
dict(do_sample=True, temperature=0.5, repetition_penalty=1.2),
]:
options.update(max_new_tokens=max_new_tokens)
for multiple_calls in [False, True]:
for inputs in [inputs_single, inputs_batch]:
torch.manual_seed(0)
outputs = make_generate_calls(model, inputs, multiple_calls=multiple_calls, **options)
torch.manual_seed(0)
ref_outputs = ref_model.generate(inputs, **options)
assert torch.allclose(
outputs, ref_outputs
), f"Sampling is not identical to HF with {options=}, {multiple_calls=}, {inputs.shape=}"
@pytest.mark.forked
def test_beam_search_generation(tokenizer, max_new_tokens=4, num_beams=2):
model = AutoDistributedModelForCausalLM.from_pretrained(
MODEL_NAME, initial_peers=INITIAL_PEERS, torch_dtype=torch.float32
)
text = "A cat sat on a mat"
inputs = tokenizer(text, return_tensors="pt")["input_ids"]
remote_outputs = model.generate(
inputs,
max_new_tokens=max_new_tokens,
num_beams=num_beams,
)
beam_scorer = BeamSearchScorer(
batch_size=inputs.size(0),
num_beams=num_beams,
device=inputs.device,
length_penalty=0,
do_early_stopping=False,
)
hf_inputs = tokenizer([text] * 2, return_tensors="pt")["input_ids"]
hf_outputs = HfGenerationMixin.beam_search(
model, input_ids=hf_inputs, max_length=inputs.size(1) + max_new_tokens, beam_scorer=beam_scorer
)
assert torch.allclose(remote_outputs, hf_outputs), "Beam search results are not identical to HF"
def test_beam_search_generation(tokenizer, model, ref_model, max_new_tokens=4, num_beams=5):
inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
options = dict(max_new_tokens=max_new_tokens, num_beams=num_beams, do_sample=False)
outputs = make_generate_calls(model, inputs, **options)
ref_outputs = ref_model.generate(inputs, **options)
assert torch.allclose(outputs, ref_outputs), "Beam search results are not identical to HF"

@ -43,7 +43,7 @@ def test_remote_sequential():
assert torch.allclose(second_half_outputs, full_outputs, atol=1e-3)
(second_half_outputs * grad_proj).sum().backward()
assert torch.allclose(test_inputs.grad, full_grad, atol=1e-2)
assert torch.allclose(test_inputs.grad, full_grad, atol=3e-2)
# test RemoteSequential with lossy compression
block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(config.num_hidden_layers)]
