Add docstrings

pull/464/head
Aleksandr Borzunov 10 months ago
parent 293d724854
commit bacdca0f5c

@@ -3,16 +3,20 @@ import dataclasses
 from typing import ContextManager, List, Optional
 
 import torch
+import transformers
 from hivemind.utils.logging import get_logger
 
 from petals.client.inference_session import InferenceSession
-from petals.utils.misc import DUMMY
+from petals.client.remote_sequential import RemoteSequential
+from petals.utils.misc import DUMMY, docstring_from
 
 logger = get_logger(__name__)
 
 
 @dataclasses.dataclass(frozen=True)
 class RemotePastKeyValues:
+    """A mock class representing the fact that `past_key_values` do exist but are stored on remote servers."""
+
     hypo_ids: Optional[torch.LongTensor] = None
 
     def __getitem__(self, _index: int) -> List[torch.Tensor]:
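
For context on this hunk (it appears to touch petals/client/remote_generation.py): a minimal, self-contained sketch of how such a mock `past_key_values` behaves. The dataclass header is copied from the diff; the `__getitem__` body and the test lines at the bottom are illustrative assumptions, not the commit's code.

import dataclasses
from typing import List, Optional

import torch

DUMMY = torch.empty(0)  # stand-in for petals.utils.misc.DUMMY


@dataclasses.dataclass(frozen=True)
class RemotePastKeyValues:
    """A mock `past_key_values`: the real attention caches live on the remote servers."""

    hypo_ids: Optional[torch.LongTensor] = None

    def __getitem__(self, _index: int) -> List[torch.Tensor]:
        return [DUMMY]  # assumed body: any layer index yields an empty placeholder


past = RemotePastKeyValues()
assert past[0][0].numel() == 0  # nothing is cached locally
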
@@ -21,33 +25,32 @@ class RemotePastKeyValues:
 
 class RemoteGenerationMixin:
     """
     A class containing all functions for auto-regressive text generation, to be used as a mixin in [`BloomForCausalLM`].
-
-    The class exposes can be used for:
-    - *greedy decoding*.
-    - *multinomial, top-k and top-p sampling*.
-    - *beam-search decoding*
-
-    This class is similar to transformer's [`generation_utils.GenerationMixin`], it can be used instead of it.
-    However, it has some differences for remote usage.
+
+    This class is an upgrade to `transformers.GenerationMixin` that:
+
+    - Is designed to be compatible with most `transformers.GenerationMixin` strategies and options
+    - Supports generation inside a remote InferenceSession, so that remote servers store your attention caches and
+      you don't have to rerun the prefix through all the servers to generate each new token
+    - Supports multiple `.generate()` calls inside one InferenceSession, so you can easily run interactive generation
+      by showing tokens on the fly (multiple calls like `.generate(None, max_new_tokens=1, ...)`) or
+      accept prompts from a user in a chatbot (multiple calls like `.generate(new_prompts, ...)`)
+    - If there is no active session, `.generate()` will create a new InferenceSession with a proper `max_length`.
+      Otherwise, `.generate()` will use the active session. You can use the `session=...` argument to override that.
     """
 
+    @docstring_from(RemoteSequential.active_session)
     @property
     def active_session(self) -> Optional[InferenceSession]:
         return self.transformer.h.active_session
 
-    def inference_session(self, **kwargs) -> ContextManager[InferenceSession]:
-        """
-        Returns an inference session for the model's RemoteSequential module.
-
-        :param max_length: Maximal expected length of inference results. Servers use this parameter
-                           to calculate the size of attention caches allocated to this client.
-        """
-        return self.transformer.h.inference_session(**kwargs)
+    @docstring_from(RemoteSequential.use_session)
+    def use_session(self, session: Optional[InferenceSession]) -> ContextManager[InferenceSession]:
+        return self.transformer.h.use_session(session)
 
+    @docstring_from(RemoteSequential.inference_session)
+    def inference_session(self, **kwargs) -> ContextManager[InferenceSession]:
+        return self.transformer.h.inference_session(**kwargs)
 
+    @docstring_from(transformers.GenerationMixin.generate)
     def generate(
         self, inputs: Optional[torch.Tensor] = None, *args, session: Optional[InferenceSession] = None, **kwargs
    ):
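
To make the new docstring concrete, here is a hedged usage sketch of session-aware generation. Only the `inference_session(max_length=...)` and `session=...` semantics come from the diff; the model class, model id, and tokenizer are illustrative assumptions.

from transformers import AutoTokenizer
from petals import AutoDistributedModelForCausalLM  # assumed import path

model_name = "bigscience/bloom-560m-petals"  # hypothetical model id
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoDistributedModelForCausalLM.from_pretrained(model_name)

inputs = tokenizer("A cat sat on", return_tensors="pt")["input_ids"]

# One session keeps attention caches on the servers across .generate() calls,
# so each new token does not rerun the prefix through all the servers.
with model.inference_session(max_length=32) as session:
    outputs = model.generate(inputs, max_new_tokens=1)  # uses the active session
    for _ in range(5):  # show tokens on the fly, one .generate() call per token
        outputs = model.generate(None, max_new_tokens=1)
        print(tokenizer.decode(outputs[0, -1:]), end="", flush=True)

# With no active session, .generate() creates its own InferenceSession:
outputs = model.generate(inputs, max_new_tokens=5)
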

@@ -61,15 +61,22 @@ class RemoteSequential(nn.Module):
     @property
     def active_session(self) -> Optional[InferenceSession]:
+        """
+        If called inside `with model.inference_session(...):` or `with model.use_session(...):`,
+        returns an active InferenceSession. Otherwise, returns None.
+        """
+
         return self._thread_local.active_session
 
     @property
     def position(self) -> int:
+        """Returns the prefix length (in tokens) in the active inference session or zero if no session is active."""
+
         return self.active_session.position if self.active_session is not None else 0
 
     @contextmanager
     def use_session(self, session: Optional[InferenceSession]) -> InferenceSession:
-        """Inside this context, forward() will use the specified InferenceSession."""
+        """Inside this context, forward() will use an _existing_ InferenceSession provided as the argument."""
         try:
             prev_session = self._thread_local.active_session
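
The `use_session` body is truncated in the hunk above; here is a generic sketch of the save-and-restore pattern it appears to implement. Names mirror the diff, but the `finally` branch is an assumption based on the visible `try:` and `prev_session` lines.

import threading
from contextlib import contextmanager


class _SessionHolderSketch:
    def __init__(self):
        self._thread_local = threading.local()

    @contextmanager
    def use_session(self, session):
        """Inside this context, forward() would route through the given session."""
        try:
            # Save whatever session (or None) the caller was already using...
            prev_session = getattr(self._thread_local, "active_session", None)
            self._thread_local.active_session = session
            yield session
        finally:
            # ...and restore it on exit, so sessions can nest safely.
            self._thread_local.active_session = prev_session
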
@@ -80,7 +87,12 @@ class RemoteSequential(nn.Module):
     @contextmanager
     def inference_session(self, **kwargs) -> InferenceSession:
-        """Inside this context, forward() will use a new InferenceSession created with given parameters."""
+        """
+        Inside this context, forward() will use a _new_ InferenceSession created with given parameters.
+
+        :param max_length: Maximal expected length of inference results. Servers use this parameter
+                           to calculate the size of attention caches allocated to this client.
+        """
         with InferenceSession(self.sequence_manager, **kwargs) as session, self.use_session(session):
             yield session
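
The `with A() as s, self.use_session(s):` line composes two context managers; an equivalent expanded form, as a sketch assuming the same names as the diff:

@contextmanager
def inference_session(self, **kwargs) -> InferenceSession:
    # Create and enter a brand-new session first...
    with InferenceSession(self.sequence_manager, **kwargs) as session:
        # ...then mark it active so that forward() picks it up.
        with self.use_session(session):
            yield session
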

@@ -5,5 +5,13 @@ DUMMY = torch.empty(0)  # dummy tensor that replaces empty prompt or adapter parameters
 DUMMY_INT64 = torch.empty(0, dtype=torch.int64)
 
 
-def is_dummy(tensor: torch.Tensor):
+def is_dummy(tensor: torch.Tensor) -> bool:
     return tensor.numel() == 0
+
+
+def docstring_from(source):
+    def add_docstring(dest):
+        dest.__doc__ = source.__doc__
+        return dest
+
+    return add_docstring
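
A quick, self-contained check of the two helpers in this hunk. Note that `docstring_from` expects the source object itself, not its `__doc__` string; passing a plain string would make `source.__doc__` resolve to `str`'s own class docstring.

import torch

DUMMY = torch.empty(0)


def is_dummy(tensor: torch.Tensor) -> bool:
    return tensor.numel() == 0


def docstring_from(source):
    def add_docstring(dest):
        dest.__doc__ = source.__doc__
        return dest

    return add_docstring


def original():
    """Documentation written once, on the delegation target."""


@docstring_from(original)
def delegate():
    ...


assert delegate.__doc__ == original.__doc__
assert is_dummy(DUMMY) and not is_dummy(torch.zeros(1))
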
