Remove unused imports and attributes (#324)

* Remove unused imports and attributes
pull/273/merge
Max Ryabinin 11 months ago committed by GitHub
parent 675bacb592
commit 3e7ae5116d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -2,7 +2,6 @@ from __future__ import annotations
import asyncio import asyncio
import itertools import itertools
import logging
import time import time
from typing import AsyncIterator, List, Optional from typing import AsyncIterator, List, Optional

@ -1,6 +1,5 @@
import os
from contextlib import contextmanager from contextlib import contextmanager
from typing import Collection, List, Optional, Union from typing import List, Optional, Union
import hivemind import hivemind
import torch import torch

@ -4,7 +4,6 @@ from typing import Optional, Union
import torch import torch
from hivemind import DHT, get_logger from hivemind import DHT, get_logger
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from torch import nn from torch import nn
import petals.client import petals.client

@ -11,7 +11,7 @@ from typing import Any, Collection, Dict, List, Optional, Sequence, Union
from weakref import WeakMethod from weakref import WeakMethod
import numpy as np import numpy as np
from hivemind import DHT, P2P, MSGPackSerializer, PeerID, get_dht_time from hivemind import DHT, P2P, MSGPackSerializer, PeerID
from hivemind.dht.node import Blacklist from hivemind.dht.node import Blacklist
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.proto import runtime_pb2 from hivemind.proto import runtime_pb2

@ -3,7 +3,6 @@ A PyTorch autograd function that runs forward/backward on a sequence of remote s
""" """
import asyncio import asyncio
import itertools import itertools
import logging
from collections import deque from collections import deque
from typing import List, Optional, Sequence, Tuple from typing import List, Optional, Sequence, Tuple

@ -8,11 +8,9 @@ from functools import partial
from typing import Dict, List, Optional, Sequence, Union from typing import Dict, List, Optional, Sequence, Union
from hivemind.dht import DHT, DHTNode, DHTValue from hivemind.dht import DHT, DHTNode, DHTValue
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.p2p import PeerID from hivemind.p2p import PeerID
from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger
import petals.client
from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo, ServerInfo, ServerState from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo, ServerInfo, ServerState
logger = get_logger(__name__) logger = get_logger(__name__)

@ -16,7 +16,7 @@ from transformers import BloomConfig
from transformers.models.bloom.modeling_bloom import BloomAttention from transformers.models.bloom.modeling_bloom import BloomAttention
from petals.data_structures import InferenceMetadata from petals.data_structures import InferenceMetadata
from petals.server.memory_cache import Handle, MemoryCache from petals.server.memory_cache import MemoryCache
from petals.server.task_pool import PrioritizedTaskPool from petals.server.task_pool import PrioritizedTaskPool
from petals.utils.misc import is_dummy from petals.utils.misc import is_dummy

@ -10,7 +10,7 @@ import ctypes
import multiprocessing as mp import multiprocessing as mp
import os import os
import time import time
from typing import AsyncContextManager, Dict, Optional, Sequence, Tuple from typing import AsyncContextManager, Dict, Optional, Sequence
import hivemind import hivemind
import torch import torch
@ -29,7 +29,7 @@ class MemoryCache:
def __init__(self, max_size_bytes: Optional[int], alloc_timeout: float): def __init__(self, max_size_bytes: Optional[int], alloc_timeout: float):
self.max_size_bytes = max_size_bytes if max_size_bytes is not None else (2**64 - 1) self.max_size_bytes = max_size_bytes if max_size_bytes is not None else (2**64 - 1)
self.alloc_timeout = alloc_timeout self.alloc_timeout = alloc_timeout
self._lock_metadata, self.size_decreased_event = mp.Lock(), mp.Event() self._lock_metadata = mp.Lock()
self._current_size = mp.Value(ctypes.c_int64, 0, lock=False) self._current_size = mp.Value(ctypes.c_int64, 0, lock=False)
self._handle_counter = mp.Value(ctypes.c_int64, 0, lock=False) self._handle_counter = mp.Value(ctypes.c_int64, 0, lock=False)
self._allocated_tensors: Dict[Handle, torch.Tensor] = {} self._allocated_tensors: Dict[Handle, torch.Tensor] = {}

@ -5,7 +5,6 @@ import time
from concurrent.futures import Future from concurrent.futures import Future
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from functools import partial from functools import partial
from secrets import token_hex
from typing import Optional from typing import Optional
import requests import requests

@ -8,7 +8,6 @@ import threading
import time import time
from typing import Dict, List, Optional, Sequence, Union from typing import Dict, List, Optional, Sequence, Union
import numpy as np
import torch import torch
from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time
from hivemind.moe.server.layers import add_custom_models_from_file from hivemind.moe.server.layers import add_custom_models_from_file
@ -502,7 +501,6 @@ class ModuleContainer(threading.Thread):
expiration=expiration, expiration=expiration,
daemon=True, daemon=True,
) )
self.checkpoint_saver = None # no need to save checkpoints since we do not change model state
if start: if start:
self.run_in_background(await_ready=True) self.run_in_background(await_ready=True)
@ -517,9 +515,6 @@ class ModuleContainer(threading.Thread):
self.online_announcer.start() self.online_announcer.start()
if self.checkpoint_saver is not None:
self.checkpoint_saver.start()
for handler in self.conn_handlers: for handler in self.conn_handlers:
handler.run_in_background() handler.run_in_background()

@ -85,7 +85,6 @@ class NucleusAlgorithm(SamplingAlgorithm):
class BeamSearchAlgorithm(DecodingAlgorithm): class BeamSearchAlgorithm(DecodingAlgorithm):
def __init__(self, num_beams: int, batch_size: int) -> None: def __init__(self, num_beams: int, batch_size: int) -> None:
self.num_beams = num_beams self.num_beams = num_beams
self._cur_num_beams = 1
self.batch_size = batch_size self.batch_size = batch_size
self._batch_beams = [list() for _ in range(batch_size)] self._batch_beams = [list() for _ in range(batch_size)]

@ -1,4 +1,3 @@
import importlib
import os import os
from hivemind.utils import logging as hm_logging from hivemind.utils import logging as hm_logging

@ -8,7 +8,6 @@ from transformers.models.bloom.configuration_bloom import BloomConfig
from petals.bloom.block import WrappedBloomBlock from petals.bloom.block import WrappedBloomBlock
from petals.bloom.from_pretrained import DTYPE_MAP, _load_state_dict, load_pretrained_block from petals.bloom.from_pretrained import DTYPE_MAP, _load_state_dict, load_pretrained_block
from petals.client import DistributedBloomConfig, RemoteSequential from petals.client import DistributedBloomConfig, RemoteSequential
from petals.data_structures import UID_DELIMITER
from test_utils import * from test_utils import *

@ -5,7 +5,6 @@ import pytest
import torch import torch
from petals.client import DistributedBloomConfig, RemoteSequential from petals.client import DistributedBloomConfig, RemoteSequential
from petals.data_structures import UID_DELIMITER
from petals.server.handler import CACHE_TOKENS_AVAILABLE from petals.server.handler import CACHE_TOKENS_AVAILABLE
from test_utils import * from test_utils import *

Loading…
Cancel
Save