Added primitives for speculative decoding and tests (#598)

This PR creates a DistributedLlamaModelForSpeculativeGeneration that implements basic speculative decoding (currently for greedy inference only).
borzunov-patch-3
Anton Sinitsin 2 months ago committed by GitHub
parent a2d4b65ae0
commit 02bbd85ed8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -83,6 +83,17 @@ class _ServerInferenceSession:
if not next_input_message.uid and not next_input_message.tensors:
break # this message means "done sending"
@property
def position(self):
return self._position
@position.setter
def position(self, start_from_position: int):
assert start_from_position <= self._position
self._position = start_from_position
if self.history is not None and self.history.shape[1] >= start_from_position:
self.history = self.history[:, :start_from_position, :] if start_from_position > 0 else None
def step(
self,
inputs: torch.Tensor,
@ -90,7 +101,6 @@ class _ServerInferenceSession:
hypo_ids: torch.LongTensor,
*,
step_id: str,
start_from_position: int,
) -> torch.Tensor:
"""
Inference step: send a chunk of input tensors and receive a chunk of outputs
@ -100,12 +110,6 @@ class _ServerInferenceSession:
if self.closed:
raise Exception("Session is closed, cannot perform step")
if start_from_position is not None:
assert start_from_position <= self._position
self._position = start_from_position
if self.history is not None and self.history.shape[1] >= start_from_position:
self.history = self.history[:, :start_from_position, :] if start_from_position > 0 else None
n_input_tokens = inputs.shape[1]
if self.history is None:
self.history = inputs
@ -127,8 +131,8 @@ class _ServerInferenceSession:
request_metadata = dict(session_id=self.session_id, step_id=step_id)
if not self.stepped:
request_metadata.update(self.session_metadata)
if start_from_position is not None:
request_metadata["start_from_position"] = start_from_position
if self._position is not None:
request_metadata["start_from_position"] = self._position
elif self.config.use_server_to_server:
next_servers = self._collect_next_servers()
if next_servers:
@ -235,6 +239,13 @@ class InferenceSession:
def position(self) -> int:
return self._position
@position.setter
def position(self, start_from_position: int) -> None:
self._position = start_from_position
for session in self._server_sessions:
assert isinstance(session, _ServerInferenceSession)
session.position = start_from_position
def _enter_server_sessions(self, chosen_spans: List[RemoteSpanInfo]) -> List[_ServerInferenceSession]:
server_sessions = []
try:
@ -275,12 +286,7 @@ class InferenceSession:
inputs: torch.Tensor,
prompts: Optional[torch.Tensor] = None,
hypo_ids: Optional[torch.Tensor] = None,
start_from_position: Optional[int] = None,
) -> torch.Tensor:
if start_from_position is not None:
self._position = start_from_position
assert not self._closed
if torch.is_grad_enabled():
logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.")
@ -324,12 +330,12 @@ class InferenceSession:
self._update_sequence(server_idx, block_idx, attempt_no)
server_session = self._server_sessions[server_idx]
assert server_session.position == self.position, f"{server_session.position} and {self.position}"
inputs = server_session.step(
inputs,
prompts[server_session.span.start : server_session.span.end],
hypo_ids,
step_id=step_id,
start_from_position=start_from_position,
)
server_idx += 1

@ -5,11 +5,13 @@ from petals.models.llama.model import (
DistributedLlamaForSequenceClassification,
DistributedLlamaModel,
)
from petals.models.llama.speculative_model import DistributedLlamaForSpeculativeGeneration
from petals.utils.auto_config import register_model_classes
register_model_classes(
config=DistributedLlamaConfig,
model=DistributedLlamaModel,
model_for_causal_lm=DistributedLlamaForCausalLM,
model_for_speculative=DistributedLlamaForSpeculativeGeneration,
model_for_sequence_classification=DistributedLlamaForSequenceClassification,
)

@ -0,0 +1,111 @@
from typing import Optional, Union
import torch
from transformers.generation import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
from transformers.generation.utils import GenerateNonBeamOutput, GenerationMixin
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama import LlamaForCausalLM
from petals.models.llama.config import DistributedLlamaConfig
from petals.models.llama.model import DistributedLlamaForCausalLM
class DistributedLlamaForSpeculativeGeneration(DistributedLlamaForCausalLM, GenerationMixin):
def __init__(self, config: DistributedLlamaConfig, small_model: LlamaForCausalLM):
DistributedLlamaForCausalLM.__init__(self, config)
self.small_model = small_model
def _sample(
self,
input_ids: torch.LongTensor,
logits_processor: LogitsProcessorList,
stopping_criteria: StoppingCriteriaList,
generation_config: GenerationConfig,
synced_gpus: bool,
streamer: Optional["BaseStreamer"],
logits_warper: Optional[LogitsProcessorList],
speculative_inference_iteration_size: int = 10,
**model_kwargs,
) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
assert not generation_config.do_sample, "sample is not working for speculative generation now"
assert not synced_gpus, "synced_gpus is not working for speculative generation now"
assert (
not generation_config.return_dict_in_generate
), "return_dict_in_generate is not working for speculative generation now"
has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
# keep track of which sequences are already finished
batch_size = input_ids.shape[0]
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
finished = False
firsts = True
while not finished:
speculative_inference_iteration_size = min(
speculative_inference_iteration_size, self.active_session._max_length - input_ids.shape[1]
)
with torch.no_grad():
speculative_outputs = self.small_model.generate(
input_ids,
max_new_tokens=speculative_inference_iteration_size,
do_sample=False,
)
speculative_tokens = speculative_outputs[:, -speculative_inference_iteration_size:]
full_sequence = torch.cat([input_ids, speculative_tokens], dim=-1)
assert input_ids.shape[1] + speculative_inference_iteration_size == full_sequence.shape[1]
input_for_validation = full_sequence
if not firsts:
self.active_session.position = input_ids.shape[1] - 1
input_for_validation = input_for_validation[:, -speculative_inference_iteration_size - 1 :]
else:
firsts = False
input_for_validation = input_for_validation[:, :-1]
with torch.no_grad():
precise_model_outputs = self(input_for_validation)
full_token_logits = precise_model_outputs.logits[:, -speculative_inference_iteration_size:, :].clone()
all_valid_tokens = []
first_token = None
for i in range(speculative_inference_iteration_size):
token_logits = full_token_logits[:, i, :]
token_scores = logits_processor(
input_for_validation[:, : -speculative_inference_iteration_size + 1 + i], token_logits
)
valid_token = torch.argmax(token_scores, dim=-1)
if first_token is None:
first_token = valid_token
if valid_token.item() == speculative_tokens[:, i].item():
all_valid_tokens.append(valid_token.unsqueeze(-1))
else:
break
if not all_valid_tokens and first_token is not None:
all_valid_tokens.append(first_token.unsqueeze(-1))
all_valid_tokens = torch.cat(all_valid_tokens, dim=-1)
# finished sentences should have their next token be a padding token
if has_eos_stopping_criteria:
all_valid_tokens = all_valid_tokens * unfinished_sequences + generation_config.pad_token_id * (
1 - unfinished_sequences
)
# update generated ids, model inputs, and length for next step
input_ids = torch.cat([input_ids, all_valid_tokens], dim=-1)
if streamer is not None:
streamer.put(all_valid_tokens.cpu())
unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, None)
finished = unfinished_sequences.max() == 0
del precise_model_outputs
if streamer is not None:
streamer.end()
return input_ids

@ -3,5 +3,6 @@ from petals.utils.auto_config import (
AutoDistributedModel,
AutoDistributedModelForCausalLM,
AutoDistributedModelForSequenceClassification,
AutoDistributedSpeculativeModel,
)
from petals.utils.dht import declare_active_modules, get_remote_module_infos

@ -15,6 +15,7 @@ class _ModelClasses:
config: Type[PretrainedConfig]
model: Optional[Type[PreTrainedModel]] = None
model_for_causal_lm: Optional[Type[PreTrainedModel]] = None
model_for_speculative: Optional[Type[PreTrainedModel]] = None
model_for_sequence_classification: Optional[Type[PreTrainedModel]] = None
@ -90,5 +91,9 @@ class AutoDistributedModelForCausalLM(DefaultRevisionMixin, _AutoDistributedBase
_mapping_field = "model_for_causal_lm"
class AutoDistributedSpeculativeModel(DefaultRevisionMixin, _AutoDistributedBase):
_mapping_field = "model_for_speculative"
class AutoDistributedModelForSequenceClassification(DefaultRevisionMixin, _AutoDistributedBase):
_mapping_field = "model_for_sequence_classification"

@ -2,8 +2,14 @@ import random
import pytest
import torch
import transformers
from petals import AutoDistributedConfig, RemoteSequential
from petals import (
AutoDistributedConfig,
AutoDistributedSpeculativeModel,
DistributedLlamaForSpeculativeGeneration,
RemoteSequential,
)
from petals.server.block_functions import MAX_SHORT_INFERENCE_TOKENS
from petals.server.from_pretrained import load_pretrained_block
from test_utils import *
@ -26,10 +32,54 @@ def test_remote_block_with_cache_invalidation_exact_match(atol_forward=1e-4, ato
with torch.inference_mode():
with remote_block.inference_session(max_length=inputs.shape[1]) as sess:
initial_outputs_inference = sess.step(inputs)
secondary_outputs_inference = sess.step(short_inputs[:, 2:, :], start_from_position=2)
sess.position = 2
secondary_outputs_inference = sess.step(short_inputs[:, 2:, :])
result = torch.cat([initial_outputs_inference[:, :2, :], secondary_outputs_inference], dim=1)
ref_block = load_pretrained_block(MODEL_NAME, block_index, torch_dtype=torch.float32)
(outputs_local,) = ref_block(short_inputs)
assert torch.allclose(outputs_local, result, rtol=0, atol=atol_inference)
@pytest.fixture
def noisy_model():
noisy_model = transformers.AutoModelForCausalLM.from_pretrained(
REF_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32
)
lm_head = noisy_model.get_output_embeddings()
assert isinstance(lm_head, torch.nn.Linear)
with torch.no_grad():
lm_head.weight += torch.randn_like(lm_head.weight) * 0.02
return noisy_model
@pytest.fixture
def model():
return transformers.AutoModelForCausalLM.from_pretrained(
MODEL_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32
)
@pytest.fixture
def tokenizer():
# We set use_fast=False since LlamaTokenizerFast is slow on load
return transformers.AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)
@pytest.mark.forked
@pytest.mark.skipif(
"llama" not in MODEL_NAME.lower(),
reason="Speculative generation now works only for llama models",
)
def test_remote_speculative_generation(tokenizer, model, noisy_model, atol_inference=1e-3):
speculated_distributed_model = AutoDistributedSpeculativeModel.from_pretrained(
MODEL_NAME, initial_peers=INITIAL_PEERS, torch_dtype=torch.float32, small_model=noisy_model
)
inputs_single = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
generated_spec = speculated_distributed_model.generate(inputs_single, max_new_tokens=100, do_sample=False)
generated_local = model.generate(inputs_single, max_new_tokens=100, do_sample=False)
assert torch.allclose(generated_spec, generated_local, rtol=0, atol=atol_inference)

Loading…
Cancel
Save