Move to .use_session()

pull/464/head
Aleksandr Borzunov 10 months ago
parent 299d0dcc87
commit 2958b3cb63

@@ -276,6 +276,8 @@ class InferenceSession:
return self
def step(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
logger.warning(f"inference_session.step: {inputs.shape=} {self.position=}")
assert not self._closed
if torch.is_grad_enabled():
logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.")

@@ -1,6 +1,6 @@
import contextlib
import dataclasses
from typing import Optional
from typing import ContextManager, Optional
import torch
from hivemind.utils.logging import get_logger
@@ -27,7 +27,11 @@ class RemoteGenerationMixin:
However, it has some differences for remote usage.
"""
def inference_session(self, **kwargs) -> InferenceSession:
@property
def active_session(self) -> Optional[InferenceSession]:
return self.transformer.h.active_session
def inference_session(self, **kwargs) -> ContextManager[InferenceSession]:
"""
Returns an inference session for the model's RemoteSequential module.
@@ -37,13 +41,16 @@ class RemoteGenerationMixin:
return self.transformer.h.inference_session(**kwargs)
def use_session(self, session: InferenceSession) -> ContextManager[InferenceSession]:
return self.transformer.h.use_session(session)
def generate(self, *args, session: Optional[InferenceSession] = None, **kwargs):
if session is None:
context_manager = self.inference_session(max_length=2048) # FIXME: Provide actual length
else:
context_manager = contextlib.nullcontext(session) # Doesn't actually enter session or exit from it
with context_manager as session:
return super().generate(*args, session=session, **kwargs)
return super().generate(*args, **kwargs)
@staticmethod
def _reorder_cache(past_key_values: RemotePastKeyValues, beam_idx: torch.LongTensor) -> RemotePastKeyValues:
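
A hedged usage sketch of the API after this change (the `model`/`tokenizer` objects, the prompt text, and the length values are assumptions, not part of the diff): `inference_session()` now returns a context manager, and `generate()` either opens a short-lived session itself or reuses one passed via `session=`.

```python
inputs = tokenizer("A cat sat on", return_tensors="pt")["input_ids"]

# One-off call: generate() opens (and closes) a temporary inference session itself.
outputs = model.generate(inputs, max_new_tokens=8, do_sample=False)

# Explicitly managed session, kept alive for calls made inside the with-block.
with model.inference_session(max_length=512) as session:
    outputs = model.generate(inputs, session=session, max_new_tokens=8, do_sample=False)
print(tokenizer.decode(outputs[0]))
```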

@@ -1,5 +1,7 @@
from __future__ import annotations
import threading
from contextlib import contextmanager
from typing import Optional, Union
import torch
@@ -46,11 +48,38 @@ class RemoteSequential(nn.Module):
sequence_manager = RemoteSequenceManager(config, block_uids, dht=dht, **kwargs)
self.sequence_manager = sequence_manager
def forward(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None):
self._thread_local = threading.local()
self._thread_local.active_session = None
def forward(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
assert inputs.ndim == 3, "inputs must be a tensor of shape [batch_size, seq_length, hidden_size]"
assert inputs.shape[1] <= 2048, "The sequence length is capped at 2048 tokens in this version"
outputs = _RemoteSequentialAutogradFunction.apply(inputs, prompts, self.sequence_manager)
return outputs
if self._thread_local.active_session is None:
assert all(v is None for v in kwargs.values()), f"Extra kwargs are not supported in forward: {kwargs}"
return _RemoteSequentialAutogradFunction.apply(inputs, prompts, self.sequence_manager)
else:
return self._thread_local.active_session.step(inputs, prompts, **kwargs)
@property
def active_session(self) -> Optional[InferenceSession]:
return self._thread_local.active_session
@contextmanager
def use_session(self, session: InferenceSession) -> InferenceSession:
""" Inside this context, forward() will use the specified InferenceSession. """
try:
prev_session = self._thread_local.active_session
self._thread_local.active_session = session
yield session
finally:
self._thread_local.active_session = prev_session
@contextmanager
def inference_session(self, **kwargs) -> InferenceSession:
""" Inside this context, forward() will use a new InferenceSession created with given parameters. """
with self.use_session(InferenceSession(self.sequence_manager, **kwargs)) as session:
yield session
def __getitem__(self, ix: Union[int, slice]) -> RemoteSequential:
return RemoteSequential(
@@ -65,8 +94,5 @@ class RemoteSequential(nn.Module):
def __len__(self):
return len(self.sequence_manager)
def inference_session(self, **kwargs) -> InferenceSession:
return InferenceSession(self.sequence_manager, **kwargs)
def extra_repr(self) -> str:
return f"modules={self.sequence_manager.block_uids[0]}..{self.sequence_manager.block_uids[-1]}"

@@ -40,19 +40,23 @@ class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel):
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[RemotePastKeyValues] = None,
session: Optional[InferenceSession] = None,
**kwargs,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> BaseModelOutputWithPast:
# FIXME: Assert that the mask is None or triangle
# FIXME: Assert that the mask is None or triangle and position_ids are valid
# assert attention_mask is None, f"{self.__class__.__name__} does not support attention masks right now"
logger.warning(f"forward: {input_ids.shape=} {session=} {kwargs=}")
logger.warning(f"forward: {input_ids=} {self.layers.active_session=}")
for k, v in kwargs.items():
if not (v is None or v is False):
logger.warning(f"Extra keyword arguments are not yet supported (got {k} = {v})")
assert use_cache is None or use_cache, "use_cache=False is not supported"
assert not output_attentions, "output_attentions=True is not supported"
assert not output_hidden_states, "output_hidden_states=True is not supported"
assert return_dict is None or return_dict, "return_dict=False is not supported"
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@@ -70,7 +74,7 @@ class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel):
if (
self.config.tuning_mode
and "ptune" in self.config.tuning_mode
and (session is None or session.position == 0)
and (self.layers.active_session is None or self.layers.active_session.position == 0)
):
batch_size = inputs_embeds.shape[0]
prompts, intermediate_prompts = self.get_prompt(batch_size)
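
The `active_session is None or position == 0` part of the gate above exists so that deep p-tuning prefixes are injected exactly once per sequence: either on a plain (non-session) forward pass, or on the first step of an inference session, but never re-prepended on later incremental steps. A restated sketch (hypothetical helper; `config` and `layers` stand for the model's config and its RemoteSequential):

```python
def should_inject_prompts(config, layers) -> bool:
    # Sketch only: mirrors the condition used in DistributedLlamaModel.forward().
    session = layers.active_session
    first_step = session is None or session.position == 0
    return bool(config.tuning_mode) and "ptune" in config.tuning_mode and first_step
```
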
@@ -81,14 +85,7 @@ class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel):
hidden_states = inputs_embeds
output_shape = input_shape + (hidden_states.size(-1),)
if session is not None:
hidden_states = session.step(
hidden_states,
prompts=intermediate_prompts,
hypo_ids=past_key_values.hypo_ids if past_key_values is not None else None,
)
else:
hidden_states = self.layers(hidden_states, prompts=intermediate_prompts)
hidden_states = self.layers(hidden_states, prompts=intermediate_prompts, hypo_ids=past_key_values.hypo_ids if past_key_values is not None else None)
# Remove prefix
if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
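
With the explicit `session.step(...)` branch gone, the model always calls `self.layers(...)` and RemoteSequential decides between the autograd path and the active session. A hedged sketch of driving it step by step from the outside (assuming `base_model` is a DistributedLlamaModel and `prefix_ids`/`next_ids` are LongTensors of token ids; the 128-token budget is illustrative):

```python
with base_model.layers.inference_session(max_length=128):
    # First call feeds the whole prefix; RemoteSequential routes it through session.step().
    out = base_model(input_ids=prefix_ids)
    # Later calls inside the same context continue from the session's current position,
    # so only the newly produced token(s) need to be fed.
    out = base_model(input_ids=next_ids)
last_hidden = out.last_hidden_state
```
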
@@ -137,74 +134,6 @@ class DistributedLlamaForCausalLM(FromPretrainedMixin, RemoteGenerationMixin, Ll
# Initialize weights and apply final processing
self.post_init()
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
session: Optional[InferenceSession] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
session=session,
)
hidden_states = outputs[0]
if self.pretraining_tp > 1:
lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.pretraining_tp, dim=0)
logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.pretraining_tp)]
logits = torch.cat(logits, dim=-1)
else:
logits = self.lm_head(hidden_states)
logits = logits.float()
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def get_output_embeddings(self):
return self.lm_head
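
Deleting this override means DistributedLlamaForCausalLM now uses the stock LlamaForCausalLM.forward (lm_head over the distributed base model); only the session plumbing underneath is custom. A hedged greedy-decoding sketch through that inherited path (the `model`/`tokenizer` objects, prompt, step count, and length budget are assumptions):

```python
import torch

input_ids = tokenizer("The capital of France is", return_tensors="pt")["input_ids"]
with model.inference_session(max_length=64):
    tokens = input_ids
    for _ in range(8):
        logits = model(input_ids=tokens).logits           # inherited HF forward + lm_head
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_token], dim=-1)
        tokens = next_token  # the session keeps earlier positions; feed only the new token
print(tokenizer.decode(input_ids[0]))
```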

@@ -101,9 +101,10 @@ def test_greedy_generation(tokenizer, model, ref_model, max_new_tokens=4):
]
for inputs in [inputs_single, inputs_batch]:
outputs = model.generate(inputs, max_new_tokens=max_new_tokens)
ref_outputs = ref_model.generate(inputs, max_new_tokens=max_new_tokens)
assert torch.allclose(outputs, ref_outputs), f"Greedy search is not identical to HF with {inputs.shape=}"
logger.warning(f"test_greedy_generation: {inputs=}")
outputs = model.generate(inputs, max_new_tokens=max_new_tokens, do_sample=False)
ref_outputs = ref_model.generate(inputs, max_new_tokens=max_new_tokens, do_sample=False)
assert torch.allclose(outputs, ref_outputs), f"Greedy generation is not identical to HF with {inputs.shape=}"
@pytest.mark.forked
