Remove unused imports and attributes (#324)

* Remove unused imports and attributes
Max Ryabinin authored 11 months ago, committed by GitHub
parent 675bacb592
commit 3e7ae5116d
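
How changes like this are typically found: the commit does not say which tool flagged these imports, but linters such as flake8 (rule F401) or ruff report unused imports automatically. As a rough illustration only, the standalone sketch below finds unused top-level imports with nothing but Python's ast module; it is a hypothetical helper, not part of this repository or this commit.

# unused_imports.py -- hypothetical sketch, not part of this commit.
# Reports imported names that never appear in a Load context elsewhere in the file.
import ast
import sys


def unused_imports(source: str) -> list:
    tree = ast.parse(source)
    imported = {}  # name -> line number where it was imported
    used = set()   # names actually read somewhere in the module
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            for alias in node.names:
                imported[alias.asname or alias.name.split(".")[0]] = node.lineno
        elif isinstance(node, ast.ImportFrom):
            for alias in node.names:
                imported[alias.asname or alias.name] = node.lineno
        elif isinstance(node, ast.Name) and isinstance(node.ctx, ast.Load):
            used.add(node.id)
    return [(lineno, name) for name, lineno in imported.items() if name not in used]


if __name__ == "__main__":
    for path in sys.argv[1:]:
        with open(path) as f:
            for lineno, name in sorted(unused_imports(f.read())):
                print(f"{path}:{lineno}: unused import '{name}'")

Note that a sketch like this misses names referenced only in strings or __all__, which is why a real linter is the safer choice; it is only meant to show the idea.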

@@ -2,7 +2,6 @@ from __future__ import annotations
import asyncio
import itertools
import logging
import time
from typing import AsyncIterator, List, Optional

@@ -1,6 +1,5 @@
import os
from contextlib import contextmanager
-from typing import Collection, List, Optional, Union
+from typing import List, Optional, Union
import hivemind
import torch

@@ -4,7 +4,6 @@ from typing import Optional, Union
import torch
from hivemind import DHT, get_logger
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from torch import nn
import petals.client

@@ -11,7 +11,7 @@ from typing import Any, Collection, Dict, List, Optional, Sequence, Union
from weakref import WeakMethod
import numpy as np
-from hivemind import DHT, P2P, MSGPackSerializer, PeerID, get_dht_time
+from hivemind import DHT, P2P, MSGPackSerializer, PeerID
from hivemind.dht.node import Blacklist
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.proto import runtime_pb2

@@ -3,7 +3,6 @@ A PyTorch autograd function that runs forward/backward on a sequence of remote s
"""
import asyncio
import itertools
import logging
from collections import deque
from typing import List, Optional, Sequence, Tuple

@@ -8,11 +8,9 @@ from functools import partial
from typing import Dict, List, Optional, Sequence, Union
from hivemind.dht import DHT, DHTNode, DHTValue
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.p2p import PeerID
from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger
import petals.client
from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo, ServerInfo, ServerState
logger = get_logger(__name__)

@@ -16,7 +16,7 @@ from transformers import BloomConfig
from transformers.models.bloom.modeling_bloom import BloomAttention
from petals.data_structures import InferenceMetadata
-from petals.server.memory_cache import Handle, MemoryCache
+from petals.server.memory_cache import MemoryCache
from petals.server.task_pool import PrioritizedTaskPool
from petals.utils.misc import is_dummy

@@ -10,7 +10,7 @@ import ctypes
import multiprocessing as mp
import os
import time
-from typing import AsyncContextManager, Dict, Optional, Sequence, Tuple
+from typing import AsyncContextManager, Dict, Optional, Sequence
import hivemind
import torch
@@ -29,7 +29,7 @@ class MemoryCache:
def __init__(self, max_size_bytes: Optional[int], alloc_timeout: float):
self.max_size_bytes = max_size_bytes if max_size_bytes is not None else (2**64 - 1)
self.alloc_timeout = alloc_timeout
-self._lock_metadata, self.size_decreased_event = mp.Lock(), mp.Event()
+self._lock_metadata = mp.Lock()
self._current_size = mp.Value(ctypes.c_int64, 0, lock=False)
self._handle_counter = mp.Value(ctypes.c_int64, 0, lock=False)
self._allocated_tensors: Dict[Handle, torch.Tensor] = {}
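
Aside on the pattern this hunk keeps (the size_decreased_event attribute was assigned but never used elsewhere): a multiprocessing value created with lock=False, guarded explicitly by a separate metadata lock. The sketch below is a hypothetical, self-contained illustration of that pattern, not code from Petals.

# Hypothetical illustration of the shared-counter pattern above; not Petals code.
import ctypes
import multiprocessing as mp


class SharedCounter:
    def __init__(self):
        # Raw shared int64 with no built-in lock; every update goes through _lock instead.
        self._lock = mp.Lock()
        self._value = mp.Value(ctypes.c_int64, 0, lock=False)

    def add(self, delta: int) -> int:
        with self._lock:  # one explicit lock for all metadata updates
            self._value.value += delta
            return self._value.value


if __name__ == "__main__":
    counter = SharedCounter()
    workers = [mp.Process(target=counter.add, args=(1,)) for _ in range(4)]
    for w in workers:
        w.start()
    for w in workers:
        w.join()
    print(counter.add(0))  # prints 4: increments from all processes are visible

Dropping the unused Event therefore changes nothing about how the counter is protected.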

@@ -5,7 +5,6 @@ import time
from concurrent.futures import Future
from contextlib import asynccontextmanager
from functools import partial
-from secrets import token_hex
from typing import Optional
import requests

@@ -8,7 +8,6 @@ import threading
import time
from typing import Dict, List, Optional, Sequence, Union
import numpy as np
import torch
from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time
from hivemind.moe.server.layers import add_custom_models_from_file
@@ -502,7 +501,6 @@ class ModuleContainer(threading.Thread):
expiration=expiration,
daemon=True,
)
-self.checkpoint_saver = None # no need to save checkpoints since we do not change model state
if start:
self.run_in_background(await_ready=True)
@@ -517,9 +515,6 @@
self.online_announcer.start()
-if self.checkpoint_saver is not None:
-self.checkpoint_saver.start()
for handler in self.conn_handlers:
handler.run_in_background()

@@ -85,7 +85,6 @@ class NucleusAlgorithm(SamplingAlgorithm):
class BeamSearchAlgorithm(DecodingAlgorithm):
def __init__(self, num_beams: int, batch_size: int) -> None:
self.num_beams = num_beams
-self._cur_num_beams = 1
self.batch_size = batch_size
self._batch_beams = [list() for _ in range(batch_size)]

@@ -1,4 +1,3 @@
-import importlib
import os
from hivemind.utils import logging as hm_logging

@@ -8,7 +8,6 @@ from transformers.models.bloom.configuration_bloom import BloomConfig
from petals.bloom.block import WrappedBloomBlock
from petals.bloom.from_pretrained import DTYPE_MAP, _load_state_dict, load_pretrained_block
from petals.client import DistributedBloomConfig, RemoteSequential
from petals.data_structures import UID_DELIMITER
from test_utils import *

@@ -5,7 +5,6 @@ import pytest
import torch
from petals.client import DistributedBloomConfig, RemoteSequential
from petals.data_structures import UID_DELIMITER
from petals.server.handler import CACHE_TOKENS_AVAILABLE
from test_utils import *
