Use get_logger(__name__) instead of get_logger(__file__) (#265)

pull/266/head
Alexander Borzunov authored 1 year ago, committed by GitHub
parent 55e7dc07a0
commit fee19e9b9b
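Why this change matters: get_logger(__file__) names each logger after the module's filesystem path, which varies between installations and does not fit Python's dot-separated logger hierarchy, while get_logger(__name__) names it after the importable module. The sketch below is illustrative only, not part of this commit; the example module path and printed values are hypothetical, and it assumes hivemind's get_logger builds on the standard logging module.

# Illustrative sketch (not from this commit). Pretend this file is imported as
# the hypothetical module petals.server.server.
from hivemind import get_logger

print(__file__)  # a path, e.g. ".../site-packages/petals/server/server.py" (install-dependent)
print(__name__)  # "petals.server.server" (stable dotted module name)

logger = get_logger(__name__)  # logger is named "petals.server.server"
logger.info("started")         # log records now carry the dotted module name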

@@ -22,7 +22,7 @@ from petals.bloom.block import WrappedBloomBlock
 from petals.server.block_utils import get_block_size
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR, allow_cache_reads, allow_cache_writes, free_disk_space_for
-logger = get_logger(__file__)
+logger = get_logger(__name__)
 CLIENT_BRANCH = "main"
 BLOCK_BRANCH_PREFIX = "block_"

@@ -13,7 +13,7 @@ from hivemind import get_logger
 from torch import nn
 from transformers import BloomConfig
-logger = get_logger(__file__)
+logger = get_logger(__name__)
 class LMHead(nn.Module):

@@ -13,7 +13,7 @@ from transformers.models.bloom.modeling_bloom import BloomModel
 from petals.bloom.from_pretrained import BLOCK_BRANCH_PREFIX, CLIENT_BRANCH
 from petals.client import DistributedBloomConfig
-logger = get_logger(__file__)
+logger = get_logger(__name__)
 DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto")

@@ -8,7 +8,7 @@ from transformers.models.bloom.modeling_bloom import build_alibi_tensor
 from petals.bloom.block import BloomBlock
-logger = get_logger(__file__)
+logger = get_logger(__name__)
 logger.warning("inference_one_block will soon be deprecated in favour of tests!")

@@ -10,7 +10,7 @@ from petals.constants import PUBLIC_INITIAL_PEERS
 from petals.server.server import Server
 from petals.utils.version import validate_version
-logger = get_logger(__file__)
+logger = get_logger(__name__)
 def main():

@@ -25,7 +25,7 @@ from petals.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, R
 from petals.server.handler import TransformerConnectionHandler
 from petals.utils.misc import DUMMY, is_dummy
-logger = get_logger(__file__)
+logger = get_logger(__name__)
 class _ServerInferenceSession:

@@ -15,7 +15,7 @@ from petals.utils.generation_algorithms import (
 )
 from petals.utils.generation_constraints import ABCBloomConstraint, EosConstraint
-logger = get_logger(__file__)
+logger = get_logger(__name__)
 class RemoteGenerationMixin:

@@ -21,7 +21,7 @@ from petals.client.remote_sequential import RemoteSequential
 from petals.constants import PUBLIC_INITIAL_PEERS
 from petals.utils.misc import DUMMY
-logger = get_logger(__file__)
+logger = get_logger(__name__)
 class DistributedBloomConfig(BloomConfig):

@@ -14,7 +14,7 @@ from petals.client.sequential_autograd import _RemoteSequentialAutogradFunction
 from petals.data_structures import UID_DELIMITER
 from petals.utils.misc import DUMMY
-logger = get_logger(__file__)
+logger = get_logger(__name__)
 class RemoteSequential(nn.Module):

@@ -6,7 +6,7 @@ from hivemind import get_logger
 from petals.data_structures import ModuleUID, RemoteModuleInfo, RemoteSpanInfo, ServerState
-logger = get_logger(__file__)
+logger = get_logger(__name__)
 T = TypeVar("T")

@@ -23,7 +23,7 @@ from petals.client.routing.spending_policy import NoSpendingPolicy
 from petals.data_structures import ModuleUID, RemoteSpanInfo, ServerState
 from petals.server.handler import TransformerConnectionHandler
-logger = get_logger(__file__)
+logger = get_logger(__name__)
 class RemoteSequenceManager:

@@ -18,7 +18,7 @@ from petals.data_structures import CHAIN_DELIMITER, RemoteSpanInfo
 from petals.server.handler import TransformerConnectionHandler
 from petals.utils.misc import DUMMY, is_dummy
-logger = get_logger(__file__)
+logger = get_logger(__name__)
 MAX_TOKENS_IN_BATCH = 1024

@@ -15,7 +15,7 @@ from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger
 import petals.client
 from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo, ServerInfo, ServerState
-logger = get_logger(__file__)
+logger = get_logger(__name__)
 def declare_active_modules(

@@ -20,7 +20,7 @@ from petals.server.memory_cache import Handle, MemoryCache
 from petals.server.task_pool import PrioritizedTaskPool
 from petals.utils.misc import is_dummy
-logger = get_logger(__file__)
+logger = get_logger(__name__)
 class TransformerBackend(ModuleBackend):

@@ -8,7 +8,7 @@ from petals.data_structures import RemoteModuleInfo, ServerState
 __all__ = ["choose_best_blocks", "should_choose_other_blocks"]
-logger = get_logger(__file__)
+logger = get_logger(__name__)
 @dataclass

@@ -32,7 +32,7 @@ from petals.server.task_pool import PrioritizedTaskPool
 from petals.server.task_prioritizer import DummyTaskPrioritizer, TaskPrioritizerBase
 from petals.utils.misc import DUMMY, is_dummy
-logger = get_logger(__file__)
+logger = get_logger(__name__)
 CACHE_TOKENS_AVAILABLE = "cache_tokens_available"

@@ -18,7 +18,7 @@ from hivemind.utils import TensorDescriptor, get_logger
 from petals.utils.asyncio import shield_and_wait
-logger = get_logger(__file__)
+logger = get_logger(__name__)
 Handle = int

@@ -31,7 +31,7 @@ from petals.server.throughput import get_dtype_name, get_host_throughput
 from petals.utils.convert_block import check_device_balance, convert_block
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR
-logger = get_logger(__file__)
+logger = get_logger(__name__)
 class Server:

@@ -12,7 +12,7 @@ from hivemind import get_logger
 from hivemind.moe.server.task_pool import TaskPoolBase
 from hivemind.utils.mpfuture import ALL_STATES, MPFuture
-logger = get_logger(__file__)
+logger = get_logger(__name__)
 @dataclass(order=True, frozen=True)

@@ -16,7 +16,7 @@ from petals.server.block_utils import resolve_block_dtype
 from petals.utils.convert_block import convert_block
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR
-logger = get_logger(__file__)
+logger = get_logger(__name__)
 try:
     import speedtest

@@ -15,7 +15,7 @@ from transformers.models.bloom.modeling_bloom import BloomAttention
 from petals.bloom.block import WrappedBloomBlock
 use_hivemind_log_handler("in_root_logger")
-logger = get_logger(__file__)
+logger = get_logger(__name__)
 def convert_block(

@@ -8,7 +8,7 @@ from typing import Optional
 import huggingface_hub
 from hivemind.utils.logging import get_logger
-logger = get_logger(__file__)
+logger = get_logger(__name__)
 DEFAULT_CACHE_DIR = os.getenv("PETALS_CACHE", Path(Path.home(), ".cache", "petals"))

@@ -4,7 +4,7 @@ from packaging.version import parse
 import petals
-logger = get_logger(__file__)
+logger = get_logger(__name__)
 def validate_version():

@@ -8,7 +8,7 @@ from transformers.models.bloom import BloomForCausalLM
 from petals.client.remote_model import DistributedBloomForCausalLM
-logger = get_logger(__file__)
+logger = get_logger(__name__)
 @pytest.mark.forked

@@ -10,7 +10,7 @@ from petals.client import RemoteSequenceManager, RemoteSequential
 from petals.client.remote_model import DistributedBloomConfig
 from petals.data_structures import UID_DELIMITER
-logger = get_logger(__file__)
+logger = get_logger(__name__)
 @pytest.mark.forked

@@ -10,7 +10,7 @@ from petals.client import RemoteSequenceManager, RemoteSequential
 from petals.client.remote_model import DistributedBloomConfig
 from petals.data_structures import UID_DELIMITER
-logger = get_logger(__file__)
+logger = get_logger(__name__)
 @pytest.mark.forked
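A side benefit of dotted logger names, shown as a usage sketch rather than anything in this diff: the standard logging hierarchy can then raise or lower verbosity for whole subpackages at once, assuming the configured log handlers respect standard logging levels.

import logging

# Hypothetical tuning examples: enable debug output for every petals.* logger,
# then make only the server subpackage quieter again.
logging.getLogger("petals").setLevel(logging.DEBUG)
logging.getLogger("petals.server").setLevel(logging.WARNING)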
