threading.local -> contextvars.ContextVar

pull/464/head
Aleksandr Borzunov 10 months ago
parent d19bac4962
commit 6657e1e901
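This commit replaces module-level threading.local state with contextvars.ContextVar throughout the client. Both give each thread an independent value, but a ContextVar is additionally isolated per asyncio task, and its set() returns a token that reset() uses to restore the previous value, which makes the save/restore context managers below shorter. A minimal, self-contained sketch of the token pattern (the names here are illustrative, not from this diff):

    from contextvars import ContextVar

    _value: ContextVar[str] = ContextVar("value", default="outer")

    def demo() -> None:
        token = _value.set("inner")  # the Token remembers the previously active value
        try:
            assert _value.get() == "inner"
        finally:
            _value.reset(token)  # restore whatever was set before
        assert _value.get() == "outer"

    demo()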

@@ -3,7 +3,7 @@ import json
 import os
 import re
 import tempfile
-import threading
+from contextvars import ContextVar
 from typing import List, Optional, Tuple, Union
 
 import torch
@@ -47,18 +47,16 @@ class FromPretrainedMixin:
     )
 
 
-_shard_config = threading.local()
-_shard_config.ignored_keys = None
+_ignored_keys = ContextVar("ignored_keys", default=None)
 
 
 @contextlib.contextmanager
 def ignore_keys(patterns: List[str]):
+    token = _ignored_keys.set(patterns)
     try:
-        prev_patterns = _shard_config.ignored_keys
-        _shard_config.ignored_keys = patterns
         yield
     finally:
-        _shard_config.ignored_keys = prev_patterns
+        _ignored_keys.reset(token)
 
 
 def patched_get_checkpoint_shard_files(
@@ -66,7 +64,7 @@ def patched_get_checkpoint_shard_files(
 ) -> Tuple[List[str], dict]:
     """Same as modeling_utils.get_checkpoint_shard_files(), but does not download shards for the ignored keys."""
 
-    should_ignore_keys = _shard_config.ignored_keys is not None
+    should_ignore_keys = _ignored_keys.get() is not None
     tempdir_ctx = tempfile.TemporaryDirectory() if should_ignore_keys else contextlib.nullcontext()
     with tempdir_ctx as tempdir:
         if should_ignore_keys:
@@ -77,7 +75,7 @@ def patched_get_checkpoint_shard_files(
             index["weight_map"] = {
                 param_name: filename
                 for param_name, filename in index["weight_map"].items()
-                if all(re.search(pattern, param_name) is None for pattern in _shard_config.ignored_keys)
+                if all(re.search(pattern, param_name) is None for pattern in _ignored_keys.get())
             }
             n_loaded_shards = len(set(index["weight_map"].values()))
             logger.debug(f"Loading {n_loaded_shards} shards out of {n_original_shards}")

@@ -76,9 +76,9 @@ def force_non_empty_weights():
     [1] https://github.com/huggingface/transformers/blob/ab9fe45236cd99b8797df78219438f8f6662bb42/src/transformers/modeling_utils.py#L2515
     """
 
-    possibly_patched_register_parameter = nn.Module.register_parameter
-    nn.Module.register_parameter = _original_register_parameter
     try:
+        possibly_patched_register_parameter = nn.Module.register_parameter
+        nn.Module.register_parameter = _original_register_parameter
         yield
     finally:
         nn.Module.register_parameter = possibly_patched_register_parameter
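The hunk above only moves the capture-and-patch of nn.Module.register_parameter inside the try block, presumably so that setup and the finally restore form one unit, consistent with the other context managers touched by this commit. A generic, hedged sketch of the same swap pattern (swap_attr and its arguments are hypothetical, not part of this diff):

    import contextlib

    @contextlib.contextmanager
    def swap_attr(obj, name, replacement):
        # Mirrors the structure above: capture and patch inside `try`,
        # restore unconditionally in `finally`.
        try:
            saved = getattr(obj, name)
            setattr(obj, name, replacement)
            yield
        finally:
            setattr(obj, name, saved)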

@@ -1,7 +1,7 @@
 from __future__ import annotations
 
-import threading
 from contextlib import contextmanager
+from contextvars import ContextVar
 from typing import Optional, Union
 
 import torch
@@ -13,7 +13,6 @@ from petals.client.inference_session import InferenceSession
 from petals.client.routing import RemoteSequenceManager
 from petals.client.sequential_autograd import _RemoteSequentialAutogradFunction
 from petals.data_structures import UID_DELIMITER
-from petals.utils.misc import DUMMY
 
 logger = get_logger(__name__)
 
@@ -48,16 +47,15 @@ class RemoteSequential(nn.Module):
             sequence_manager = RemoteSequenceManager(config, block_uids, dht=dht, **kwargs)
         self.sequence_manager = sequence_manager
 
-        self._thread_local = threading.local()
-        self._thread_local.active_session = None
+        self._active_session = ContextVar("active_session", default=None)
 
     def forward(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
         assert inputs.ndim == 3, "inputs must be a tensor of shape [batch_size, seq_length, hidden_size]"
-        if self._thread_local.active_session is None:
+        if self.active_session is None:
             assert all(v is None for v in kwargs.values()), f"Extra kwargs are not supported in forward: {kwargs}"
             return _RemoteSequentialAutogradFunction.apply(inputs, prompts, self.sequence_manager)
         else:
-            return self._thread_local.active_session.step(inputs, prompts, **kwargs)
+            return self.active_session.step(inputs, prompts, **kwargs)
 
     @property
     def active_session(self) -> Optional[InferenceSession]:
@@ -66,7 +64,7 @@ class RemoteSequential(nn.Module):
         returns an active InferenceSession. Otherwise, returns None.
         """
 
-        return self._thread_local.active_session
+        return self._active_session.get()
 
     @property
     def position(self) -> int:
@@ -78,12 +76,11 @@ class RemoteSequential(nn.Module):
     def use_session(self, session: Optional[InferenceSession]) -> InferenceSession:
         """Inside this context, forward() will use an _existing_ InferenceSession provided as the argument."""
 
+        token = self._active_session.set(session)
         try:
-            prev_session = self._thread_local.active_session
-            self._thread_local.active_session = session
             yield session
         finally:
-            self._thread_local.active_session = prev_session
+            self._active_session.reset(token)
 
     @contextmanager
     def inference_session(self, **kwargs) -> InferenceSession:
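With the session stored in a ContextVar, each thread and each asyncio task now sees its own active session, and use_session() restores the previous one even if the body raises. A hedged usage sketch (remote_sequential, hidden_states, and the max_length kwarg are illustrative):

    # Hypothetical usage: inside the block, forward() routes through `sess`
    # via sess.step(...); on exit, the previous session (or None) is active.
    with remote_sequential.inference_session(max_length=128) as sess:
        outputs = remote_sequential(hidden_states)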
