@@ -3,16 +3,20 @@ import dataclasses
from typing import ContextManager, List, Optional

import torch
+import transformers
from hivemind.utils.logging import get_logger

from petals.client.inference_session import InferenceSession
-from petals.utils.misc import DUMMY
+from petals.client.remote_sequential import RemoteSequential
+from petals.utils.misc import DUMMY, docstring_from

logger = get_logger(__name__)


@dataclasses.dataclass(frozen=True)
class RemotePastKeyValues:
    """A mock class representing the fact that `past_key_values` do exist but are stored on remote servers."""

    hypo_ids: Optional[torch.LongTensor] = None

    def __getitem__(self, _index: int) -> List[torch.Tensor]:
@@ -21,33 +25,32 @@ class RemotePastKeyValues:

class RemoteGenerationMixin:
    """
-    A class containing all functions for auto-regressive text generation, to be used as a mixin in [`BloomForCausalLM`].
-    The class can be used for:
-
-    - *greedy decoding*
-    - *multinomial, top-k and top-p sampling*
-    - *beam-search decoding*
-
-    This class is similar to transformers' [`generation_utils.GenerationMixin`] and can be used instead of it.
-    However, it has some differences for remote usage.
+    This class is an upgrade to `transformers.GenerationMixin` that:
+
+    - Is designed to be compatible with most `transformers.GenerationMixin` strategies and options
+    - Supports generation inside a remote InferenceSession, so that remote servers store your attention caches and
+      you don't have to rerun the prefix through all the servers to generate each new token
+    - Supports multiple `.generate()` calls inside one InferenceSession, so you can easily run interactive generation
+      by showing tokens on the fly (multiple calls like `.generate(None, max_new_tokens=1, ...)`) or
+      accept prompts from a user in a chatbot (multiple calls like `.generate(new_prompts, ...)`).
+    - If there is no active session, `.generate()` will create a new InferenceSession with the proper `max_length`.
+      Otherwise, `.generate()` will use the active session. You can use the `session=...` argument to override that.
    """

+    @docstring_from(RemoteSequential.active_session)
+    @property
+    def active_session(self) -> Optional[InferenceSession]:
+        return self.transformer.h.active_session

-    def inference_session(self, **kwargs) -> ContextManager[InferenceSession]:
-        """
-        Returns an inference session for the model's RemoteSequential module.
-
-        :param max_length: Maximal expected length of inference results. Servers use this parameter
-            to calculate the size of attention caches allocated to this client.
-        """
-        return self.transformer.h.inference_session(**kwargs)

+    @docstring_from(RemoteSequential.use_session)
+    def use_session(self, session: Optional[InferenceSession]) -> ContextManager[InferenceSession]:
+        return self.transformer.h.use_session(session)

+    @docstring_from(RemoteSequential.inference_session)
+    def inference_session(self, **kwargs) -> ContextManager[InferenceSession]:
+        return self.transformer.h.inference_session(**kwargs)

+    @docstring_from(transformers.GenerationMixin.generate)
+    def generate(
+        self, inputs: Optional[torch.Tensor] = None, *args, session: Optional[InferenceSession] = None, **kwargs
+    ):
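
The diff is easier to follow with a usage sketch. The snippet below shows how the new session-aware `.generate()` is meant to be called; it is illustrative only and assumes a Petals client model such as `DistributedBloomForCausalLM` (the model id, prompt, and token counts are placeholders, not part of the diff).

from transformers import AutoTokenizer

from petals import DistributedBloomForCausalLM

MODEL_NAME = "bigscience/bloom-petals"  # illustrative model id

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME)

prefix = tokenizer("A cat sat on", return_tensors="pt")["input_ids"]

# Open one InferenceSession so the remote servers keep attention caches between
# `.generate()` calls. `max_length` is the total sequence length (prefix plus all
# new tokens) that servers use to size the caches allocated to this client.
with model.inference_session(max_length=32) as session:
    # First call: run the prefix through the chain of servers, get one token.
    outputs = model.generate(prefix, max_new_tokens=1, session=session)
    # Later calls pass `inputs=None`: generation continues from the cached
    # prefix, so tokens can be streamed one at a time without re-sending it.
    for _ in range(10):
        outputs = model.generate(None, max_new_tokens=1, session=session)
        print(tokenizer.decode(outputs[0, -1:]), end="", flush=True)

Per the new docstring, `.generate()` falls back to the active session when one exists, so the explicit `session=session` arguments above are likely redundant inside the `with` block; they are kept for clarity. The same pattern works for a chatbot that appends user prompts between calls, i.e. repeated `.generate(new_prompts, ...)` within one session.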