Fix logging: do not duplicate lines, enable colors in Colab (#156)

pull/148/head^2
Alexander Borzunov committed 1 year ago via GitHub
parent 041ad20891
commit 668b736031

@@ -1 +1,5 @@
+import petals.utils.logging
+
 __version__ = "1.0alpha1"
+
+petals.utils.logging.initialize_logs()
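
With this hunk, logging setup becomes a side effect of `import petals`. The only opt-out is the `PETALS_LOGGING` environment variable checked by the new `petals/utils/logging.py` module added later in this diff. A minimal sketch of opting out, assuming the import-time behavior introduced here:

import os

os.environ["PETALS_LOGGING"] = "False"  # must be set before the first import of petals
import petals  # initialize_logs() sees the variable and returns without touching handlers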

@@ -13,7 +13,7 @@ import time
 from typing import Optional, OrderedDict, Union

 import torch
-from hivemind.utils.logging import get_logger, use_hivemind_log_handler
+from hivemind.utils.logging import get_logger
 from transformers.modeling_utils import WEIGHTS_NAME
 from transformers.models.bloom.configuration_bloom import BloomConfig
 from transformers.utils import get_file_from_repo
@@ -22,7 +22,6 @@ from petals.bloom.block import WrappedBloomBlock
 from petals.server.block_utils import get_block_size
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR, allow_cache_reads, allow_cache_writes, free_disk_space_for

-use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)

 CLIENT_BRANCH = "main"
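
This same two-line pattern (drop `use_hivemind_log_handler` from the import, delete the module-level call) repeats in every file below: setup now happens exactly once, in the new `petals.utils.logging` module. The duplicated lines this PR fixes come from the root logger holding more than one handler (e.g., Colab's defaults plus hivemind's, which is why the new code calls `handlers.clear()`). A standalone illustration of that failure mode, not code from this commit:

import logging

def install_handler():  # hypothetical stand-in for a per-module logging setup call
    logging.getLogger().addHandler(logging.StreamHandler())

install_handler()  # module A configures logging on import
install_handler()  # module B does the same
logging.getLogger().warning("hello")  # printed twice, once per attached handler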

@@ -7,13 +7,11 @@ See commit history for authorship.
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
-from hivemind import use_hivemind_log_handler
+from hivemind import get_logger
 from torch import nn
 from transformers import BloomConfig
-from transformers.utils import logging

-use_hivemind_log_handler("in_root_logger")
-logger = logging.get_logger(__file__)
+logger = get_logger(__file__)


 class LMHead(nn.Module):

@@ -5,7 +5,7 @@ import psutil
 import torch.backends.quantized
 import torch.nn as nn
 import transformers
-from hivemind.utils.logging import get_logger, use_hivemind_log_handler
+from hivemind.utils.logging import get_logger
 from huggingface_hub import Repository
 from tqdm.auto import tqdm
 from transformers.models.bloom.modeling_bloom import BloomModel
@@ -13,7 +13,6 @@ from transformers.models.bloom.modeling_bloom import BloomModel
 from petals.bloom.from_pretrained import BLOCK_BRANCH_PREFIX, CLIENT_BRANCH
 from petals.client import DistributedBloomConfig

-use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)

 DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto")

@@ -1,14 +1,13 @@
 import argparse

 import torch
-from hivemind.utils.logging import get_logger, use_hivemind_log_handler
+from hivemind.utils.logging import get_logger
 from tqdm.auto import trange
 from transformers import BloomConfig
 from transformers.models.bloom.modeling_bloom import build_alibi_tensor

 from petals.bloom.block import BloomBlock

-use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)

 logger.warning("inference_one_block will soon be deprecated in favour of tests!")

@@ -3,13 +3,12 @@ import argparse
 import configargparse
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils.limits import increase_file_limit
-from hivemind.utils.logging import get_logger, use_hivemind_log_handler
+from hivemind.utils.logging import get_logger
 from humanfriendly import parse_size

 from petals.constants import PUBLIC_INITIAL_PEERS
 from petals.server.server import Server

-use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)

@@ -5,7 +5,7 @@ from typing import List, Optional
 import hivemind
 import torch
 import torch.nn as nn
-from hivemind.utils.logging import get_logger, loglevel, use_hivemind_log_handler
+from hivemind.utils.logging import get_logger
 from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
 from transformers.models.bloom import (
     BloomConfig,
@@ -21,13 +21,8 @@ from petals.client.remote_sequential import RemoteSequential
 from petals.constants import PUBLIC_INITIAL_PEERS
 from petals.utils.misc import DUMMY

-use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)

-# We suppress asyncio error logs by default since they are mostly not relevant for the end user
-asyncio_loglevel = os.getenv("PETALS_ASYNCIO_LOGLEVEL", "FATAL" if loglevel != "DEBUG" else "DEBUG")
-get_logger("asyncio").setLevel(asyncio_loglevel)
-

 class DistributedBloomConfig(BloomConfig):
     """

@@ -3,7 +3,7 @@ from __future__ import annotations
 from typing import Optional, Union

 import torch
-from hivemind import DHT, P2P, get_logger, use_hivemind_log_handler
+from hivemind import DHT, P2P, get_logger
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from torch import nn
@@ -14,7 +14,6 @@ from petals.client.sequential_autograd import _RemoteSequentialAutogradFunction
 from petals.data_structures import UID_DELIMITER
 from petals.utils.misc import DUMMY

-use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)

@@ -2,11 +2,10 @@ import dataclasses
 import time
 from typing import Iterable, List, Optional, Sequence, Tuple, Type, TypeVar

-from hivemind import get_logger, use_hivemind_log_handler
+from hivemind import get_logger

 from petals.data_structures import ModuleUID, RemoteModuleInfo, RemoteSpanInfo, ServerState

-use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)

@@ -14,7 +14,7 @@ from hivemind.dht.node import Blacklist
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from hivemind.p2p import P2PHandlerError
 from hivemind.proto import runtime_pb2
-from hivemind.utils.logging import get_logger, use_hivemind_log_handler
+from hivemind.utils.logging import get_logger

 import petals.dht_utils
 from petals.client.routing.sequence_info import RemoteSequenceInfo
@@ -22,7 +22,6 @@ from petals.client.routing.spending_policy import NoSpendingPolicy
 from petals.data_structures import ModuleUID, RemoteSpanInfo, ServerState
 from petals.server.handler import TransformerConnectionHandler

-use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)

@@ -10,12 +10,11 @@ from typing import Dict, List, Optional, Sequence, Union
 from hivemind.dht import DHT, DHTNode, DHTValue
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from hivemind.p2p import PeerID
-from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger, use_hivemind_log_handler
+from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger

 import petals.client
 from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo, ServerInfo, ServerState

-use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)

@@ -2,7 +2,7 @@
 from typing import Any, Dict, Sequence, Tuple

 import torch
-from hivemind import BatchTensorDescriptor, use_hivemind_log_handler
+from hivemind import BatchTensorDescriptor
 from hivemind.moe.server.module_backend import ModuleBackend
 from hivemind.utils import get_logger
@@ -11,7 +11,6 @@ from petals.server.memory_cache import MemoryCache
 from petals.server.task_pool import PrioritizedTaskPool
 from petals.utils.misc import is_dummy

-use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)

@@ -14,10 +14,8 @@ from typing import AsyncContextManager, Dict, Optional, Union

 import hivemind
 import torch
-from hivemind import use_hivemind_log_handler
 from hivemind.utils import TensorDescriptor, get_logger

-use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)

 Handle = int

@@ -16,7 +16,7 @@ from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor
 from hivemind.moe.server.layers import add_custom_models_from_file
 from hivemind.moe.server.runtime import Runtime
 from hivemind.proto.runtime_pb2 import CompressionType
-from hivemind.utils.logging import get_logger, use_hivemind_log_handler
+from hivemind.utils.logging import get_logger
 from transformers import BloomConfig

 from petals.bloom.from_pretrained import DTYPE_MAP, load_pretrained_block
@@ -32,7 +32,6 @@ from petals.server.throughput import get_host_throughput
 from petals.utils.convert_8bit import replace_8bit_linear
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR

-use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)

@@ -8,11 +8,10 @@ from queue import PriorityQueue
 from typing import Any, List, Optional, Sequence, Tuple

 import torch
-from hivemind import get_logger, use_hivemind_log_handler
+from hivemind import get_logger
 from hivemind.moe.server.task_pool import TaskPoolBase
 from hivemind.utils.mpfuture import ALL_STATES, MPFuture

-use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)

@@ -8,7 +8,7 @@ from pathlib import Path
 from typing import Optional, Union

 import torch
-from hivemind.utils.logging import get_logger, use_hivemind_log_handler
+from hivemind.utils.logging import get_logger
 from transformers import BloomConfig

 from petals.bloom.block import WrappedBloomBlock
@@ -16,7 +16,6 @@ from petals.server.block_utils import resolve_block_dtype
 from petals.utils.convert_8bit import replace_8bit_linear
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR

-use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)

@@ -0,0 +1,34 @@
+import importlib
+import os
+
+from hivemind.utils import logging as hm_logging
+
+
+def in_jupyter() -> bool:
+    """Check if the code is run in Jupyter or Colab"""
+
+    try:
+        __IPYTHON__
+        return True
+    except NameError:
+        return False
+
+
+def initialize_logs():
+    """Initialize Petals logging tweaks. This function is called when you import the `petals` module."""
+
+    # The env var PETALS_LOGGING=False prohibits Petals from touching logs at all
+    if os.getenv("PETALS_LOGGING", "True").lower() in ("false", "0"):
+        return
+
+    if in_jupyter():
+        os.environ["HIVEMIND_COLORS"] = "True"
+    importlib.reload(hm_logging)
+
+    hm_logging.get_logger().handlers.clear()  # Remove extra default handlers on Colab
+    hm_logging.use_hivemind_log_handler("in_root_logger")
+
+    # We suppress asyncio error logs by default since they are mostly not relevant to the end user,
+    # unless the env var PETALS_ASYNCIO_LOGLEVEL overrides this
+    asyncio_loglevel = os.getenv("PETALS_ASYNCIO_LOGLEVEL", "FATAL" if hm_logging.loglevel != "DEBUG" else "DEBUG")
+    hm_logging.get_logger("asyncio").setLevel(asyncio_loglevel)

@@ -5,10 +5,9 @@ from contextlib import suppress
 import psutil
 import pytest
 from hivemind.utils.crypto import RSAPrivateKey
-from hivemind.utils.logging import get_logger, use_hivemind_log_handler
+from hivemind.utils.logging import get_logger
 from hivemind.utils.mpfuture import MPFuture

-use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__name__)

@@ -1,14 +1,13 @@
 import pytest
 import torch
 import transformers
-from hivemind import get_logger, use_hivemind_log_handler
+from hivemind import get_logger
 from test_utils import *
 from transformers.generation import BeamSearchScorer
 from transformers.models.bloom import BloomForCausalLM

 from petals.client.remote_model import DistributedBloomForCausalLM

-use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)

@@ -1,6 +1,6 @@
 import pytest
 import torch
-from hivemind import DHT, BatchTensorDescriptor, get_logger, use_hivemind_log_handler
+from hivemind import DHT, BatchTensorDescriptor, get_logger
 from hivemind.proto import runtime_pb2
 from test_utils import *
@@ -9,7 +9,6 @@ from petals.client import RemoteSequenceManager, RemoteSequential
 from petals.client.remote_model import DistributedBloomConfig
 from petals.data_structures import UID_DELIMITER

-use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)

@@ -3,14 +3,13 @@ import time
 import pytest
 import torch
-from hivemind import DHT, get_logger, use_hivemind_log_handler
+from hivemind import DHT, get_logger
 from test_utils import *

 from petals.client import RemoteSequenceManager, RemoteSequential
 from petals.client.remote_model import DistributedBloomConfig
 from petals.data_structures import UID_DELIMITER

-use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
