|
|
|
@ -1,7 +1,7 @@
|
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
|
|
import threading
|
|
|
|
|
from contextlib import contextmanager
|
|
|
|
|
from contextvars import ContextVar
|
|
|
|
|
from typing import Optional, Union
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
@ -13,7 +13,6 @@ from petals.client.inference_session import InferenceSession
|
|
|
|
|
from petals.client.routing import RemoteSequenceManager
|
|
|
|
|
from petals.client.sequential_autograd import _RemoteSequentialAutogradFunction
|
|
|
|
|
from petals.data_structures import UID_DELIMITER
|
|
|
|
|
from petals.utils.misc import DUMMY
|
|
|
|
|
|
|
|
|
|
# Module-level logger, created via get_logger() and keyed by this module's name.
logger = get_logger(__name__)
|
|
|
|
|
|
|
|
|
@ -48,16 +47,15 @@ class RemoteSequential(nn.Module):
|
|
|
|
|
sequence_manager = RemoteSequenceManager(config, block_uids, dht=dht, **kwargs)
|
|
|
|
|
self.sequence_manager = sequence_manager
|
|
|
|
|
|
|
|
|
|
self._thread_local = threading.local()
|
|
|
|
|
self._thread_local.active_session = None
|
|
|
|
|
self._active_session = ContextVar("active_session", default=None)
|
|
|
|
|
|
|
|
|
|
def forward(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
    """
    Run `inputs` through the remote sequence, either statelessly or via the active inference session.

    When `self.active_session` is None, the call is dispatched to `_RemoteSequentialAutogradFunction`
    together with `self.sequence_manager`. Otherwise it is delegated to the active session's `step()`.

    :param inputs: a tensor of shape [batch_size, seq_length, hidden_size]
    :param prompts: optional tensor, passed through unchanged to whichever path handles the call
    :param kwargs: forwarded to `step()` when a session is active; without a session every value must be None
    :returns: the result of the autograd function or of the session's `step()`
    :raises AssertionError: if `inputs` is not 3-dimensional, or if a non-None extra kwarg is passed
        while no session is active
    """
    assert inputs.ndim == 3, "inputs must be a tensor of shape [batch_size, seq_length, hidden_size]"

    # Read the session once, through the public property, so that the check and the call
    # below are guaranteed to see the same object.
    session = self.active_session
    if session is None:
        # The session-less path takes no extra arguments, so reject anything that is actually set.
        assert all(v is None for v in kwargs.values()), f"Extra kwargs are not supported in forward: {kwargs}"
        return _RemoteSequentialAutogradFunction.apply(inputs, prompts, self.sequence_manager)
    return session.step(inputs, prompts, **kwargs)
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def active_session(self) -> Optional[InferenceSession]:
|
|
|
|
@ -66,7 +64,7 @@ class RemoteSequential(nn.Module):
|
|
|
|
|
returns an active InferenceSession. Otherwise, returns None.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
return self._thread_local.active_session
|
|
|
|
|
return self._active_session.get()
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def position(self) -> int:
|
|
|
|
@ -78,12 +76,11 @@ class RemoteSequential(nn.Module):
|
|
|
|
|
def use_session(self, session: Optional[InferenceSession]) -> InferenceSession:
    """
    Inside this context, forward() will use an _existing_ InferenceSession provided as the argument.

    :param session: the session to make active; it is stored in `self._active_session` for the
        duration of the context and yielded back to the caller unchanged
    """
    # set() returns a token that records the value being replaced, so nested use_session()
    # calls unwind correctly without tracking the previous session by hand.
    token = self._active_session.set(session)
    try:
        yield session
    finally:
        # Always restore the previous value, even if the body of the context raised.
        self._active_session.reset(token)
|
|
|
|
|
|
|
|
|
|
@contextmanager
|
|
|
|
|
def inference_session(self, **kwargs) -> InferenceSession:
|
|
|
|
|