From 668b736031c87679db91f4bf453b4acc95623b0e Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Thu, 15 Dec 2022 09:12:18 +0400 Subject: [PATCH] Fix logging: do not duplicate lines, enable colors in Colab (#156) --- src/petals/__init__.py | 4 +++ src/petals/bloom/from_pretrained.py | 3 +- src/petals/bloom/modeling_utils.py | 6 ++-- src/petals/cli/convert_model.py | 3 +- src/petals/cli/inference_one_block.py | 3 +- src/petals/cli/run_server.py | 3 +- src/petals/client/remote_model.py | 7 +--- src/petals/client/remote_sequential.py | 3 +- src/petals/client/routing/sequence_info.py | 3 +- src/petals/client/routing/sequence_manager.py | 3 +- src/petals/dht_utils.py | 3 +- src/petals/server/backend.py | 3 +- src/petals/server/memory_cache.py | 2 -- src/petals/server/server.py | 3 +- src/petals/server/task_pool.py | 3 +- src/petals/server/throughput.py | 3 +- src/petals/utils/logging.py | 34 +++++++++++++++++++ tests/conftest.py | 3 +- tests/test_full_model.py | 3 +- tests/test_remote_sequential.py | 3 +- tests/test_sequence_manager.py | 3 +- 21 files changed, 57 insertions(+), 44 deletions(-) create mode 100644 src/petals/utils/logging.py diff --git a/src/petals/__init__.py b/src/petals/__init__.py index e7110a9..066180b 100644 --- a/src/petals/__init__.py +++ b/src/petals/__init__.py @@ -1 +1,5 @@ +import petals.utils.logging + __version__ = "1.0alpha1" + +petals.utils.logging.initialize_logs() diff --git a/src/petals/bloom/from_pretrained.py b/src/petals/bloom/from_pretrained.py index b9acb6e..fa31602 100644 --- a/src/petals/bloom/from_pretrained.py +++ b/src/petals/bloom/from_pretrained.py @@ -13,7 +13,7 @@ import time from typing import Optional, OrderedDict, Union import torch -from hivemind.utils.logging import get_logger, use_hivemind_log_handler +from hivemind.utils.logging import get_logger from transformers.modeling_utils import WEIGHTS_NAME from transformers.models.bloom.configuration_bloom import BloomConfig from transformers.utils import get_file_from_repo @@ -22,7 +22,6 @@ from petals.bloom.block import WrappedBloomBlock from petals.server.block_utils import get_block_size from petals.utils.disk_cache import DEFAULT_CACHE_DIR, allow_cache_reads, allow_cache_writes, free_disk_space_for -use_hivemind_log_handler("in_root_logger") logger = get_logger(__file__) CLIENT_BRANCH = "main" diff --git a/src/petals/bloom/modeling_utils.py b/src/petals/bloom/modeling_utils.py index 9ae60dc..6847423 100644 --- a/src/petals/bloom/modeling_utils.py +++ b/src/petals/bloom/modeling_utils.py @@ -7,13 +7,11 @@ See commit history for authorship. import torch import torch.nn.functional as F import torch.utils.checkpoint -from hivemind import use_hivemind_log_handler +from hivemind import get_logger from torch import nn from transformers import BloomConfig -from transformers.utils import logging -use_hivemind_log_handler("in_root_logger") -logger = logging.get_logger(__file__) +logger = get_logger(__file__) class LMHead(nn.Module): diff --git a/src/petals/cli/convert_model.py b/src/petals/cli/convert_model.py index 2678eea..c4746fd 100644 --- a/src/petals/cli/convert_model.py +++ b/src/petals/cli/convert_model.py @@ -5,7 +5,7 @@ import psutil import torch.backends.quantized import torch.nn as nn import transformers -from hivemind.utils.logging import get_logger, use_hivemind_log_handler +from hivemind.utils.logging import get_logger from huggingface_hub import Repository from tqdm.auto import tqdm from transformers.models.bloom.modeling_bloom import BloomModel @@ -13,7 +13,6 @@ from transformers.models.bloom.modeling_bloom import BloomModel from petals.bloom.from_pretrained import BLOCK_BRANCH_PREFIX, CLIENT_BRANCH from petals.client import DistributedBloomConfig -use_hivemind_log_handler("in_root_logger") logger = get_logger(__file__) DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto") diff --git a/src/petals/cli/inference_one_block.py b/src/petals/cli/inference_one_block.py index 336e2a3..9f7c5b4 100644 --- a/src/petals/cli/inference_one_block.py +++ b/src/petals/cli/inference_one_block.py @@ -1,14 +1,13 @@ import argparse import torch -from hivemind.utils.logging import get_logger, use_hivemind_log_handler +from hivemind.utils.logging import get_logger from tqdm.auto import trange from transformers import BloomConfig from transformers.models.bloom.modeling_bloom import build_alibi_tensor from petals.bloom.block import BloomBlock -use_hivemind_log_handler("in_root_logger") logger = get_logger(__file__) logger.warning("inference_one_block will soon be deprecated in favour of tests!") diff --git a/src/petals/cli/run_server.py b/src/petals/cli/run_server.py index 27cca30..79c1b9d 100644 --- a/src/petals/cli/run_server.py +++ b/src/petals/cli/run_server.py @@ -3,13 +3,12 @@ import argparse import configargparse from hivemind.proto.runtime_pb2 import CompressionType from hivemind.utils.limits import increase_file_limit -from hivemind.utils.logging import get_logger, use_hivemind_log_handler +from hivemind.utils.logging import get_logger from humanfriendly import parse_size from petals.constants import PUBLIC_INITIAL_PEERS from petals.server.server import Server -use_hivemind_log_handler("in_root_logger") logger = get_logger(__file__) diff --git a/src/petals/client/remote_model.py b/src/petals/client/remote_model.py index daaef83..c45ef38 100644 --- a/src/petals/client/remote_model.py +++ b/src/petals/client/remote_model.py @@ -5,7 +5,7 @@ from typing import List, Optional import hivemind import torch import torch.nn as nn -from hivemind.utils.logging import get_logger, loglevel, use_hivemind_log_handler +from hivemind.utils.logging import get_logger from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions from transformers.models.bloom import ( BloomConfig, @@ -21,13 +21,8 @@ from petals.client.remote_sequential import RemoteSequential from petals.constants import PUBLIC_INITIAL_PEERS from petals.utils.misc import DUMMY -use_hivemind_log_handler("in_root_logger") logger = get_logger(__file__) -# We suppress asyncio error logs by default since they are mostly not relevant for the end user -asyncio_loglevel = os.getenv("PETALS_ASYNCIO_LOGLEVEL", "FATAL" if loglevel != "DEBUG" else "DEBUG") -get_logger("asyncio").setLevel(asyncio_loglevel) - class DistributedBloomConfig(BloomConfig): """ diff --git a/src/petals/client/remote_sequential.py b/src/petals/client/remote_sequential.py index aee8d67..2dc3c5b 100644 --- a/src/petals/client/remote_sequential.py +++ b/src/petals/client/remote_sequential.py @@ -3,7 +3,7 @@ from __future__ import annotations from typing import Optional, Union import torch -from hivemind import DHT, P2P, get_logger, use_hivemind_log_handler +from hivemind import DHT, P2P, get_logger from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker from torch import nn @@ -14,7 +14,6 @@ from petals.client.sequential_autograd import _RemoteSequentialAutogradFunction from petals.data_structures import UID_DELIMITER from petals.utils.misc import DUMMY -use_hivemind_log_handler("in_root_logger") logger = get_logger(__file__) diff --git a/src/petals/client/routing/sequence_info.py b/src/petals/client/routing/sequence_info.py index 36eaefc..e69cd35 100644 --- a/src/petals/client/routing/sequence_info.py +++ b/src/petals/client/routing/sequence_info.py @@ -2,11 +2,10 @@ import dataclasses import time from typing import Iterable, List, Optional, Sequence, Tuple, Type, TypeVar -from hivemind import get_logger, use_hivemind_log_handler +from hivemind import get_logger from petals.data_structures import ModuleUID, RemoteModuleInfo, RemoteSpanInfo, ServerState -use_hivemind_log_handler("in_root_logger") logger = get_logger(__file__) diff --git a/src/petals/client/routing/sequence_manager.py b/src/petals/client/routing/sequence_manager.py index 484e134..bb93158 100644 --- a/src/petals/client/routing/sequence_manager.py +++ b/src/petals/client/routing/sequence_manager.py @@ -14,7 +14,7 @@ from hivemind.dht.node import Blacklist from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker from hivemind.p2p import P2PHandlerError from hivemind.proto import runtime_pb2 -from hivemind.utils.logging import get_logger, use_hivemind_log_handler +from hivemind.utils.logging import get_logger import petals.dht_utils from petals.client.routing.sequence_info import RemoteSequenceInfo @@ -22,7 +22,6 @@ from petals.client.routing.spending_policy import NoSpendingPolicy from petals.data_structures import ModuleUID, RemoteSpanInfo, ServerState from petals.server.handler import TransformerConnectionHandler -use_hivemind_log_handler("in_root_logger") logger = get_logger(__file__) diff --git a/src/petals/dht_utils.py b/src/petals/dht_utils.py index 3f42f1e..09aa27a 100644 --- a/src/petals/dht_utils.py +++ b/src/petals/dht_utils.py @@ -10,12 +10,11 @@ from typing import Dict, List, Optional, Sequence, Union from hivemind.dht import DHT, DHTNode, DHTValue from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker from hivemind.p2p import PeerID -from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger, use_hivemind_log_handler +from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger import petals.client from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo, ServerInfo, ServerState -use_hivemind_log_handler("in_root_logger") logger = get_logger(__file__) diff --git a/src/petals/server/backend.py b/src/petals/server/backend.py index f1b460d..93339b7 100644 --- a/src/petals/server/backend.py +++ b/src/petals/server/backend.py @@ -2,7 +2,7 @@ from typing import Any, Dict, Sequence, Tuple import torch -from hivemind import BatchTensorDescriptor, use_hivemind_log_handler +from hivemind import BatchTensorDescriptor from hivemind.moe.server.module_backend import ModuleBackend from hivemind.utils import get_logger @@ -11,7 +11,6 @@ from petals.server.memory_cache import MemoryCache from petals.server.task_pool import PrioritizedTaskPool from petals.utils.misc import is_dummy -use_hivemind_log_handler("in_root_logger") logger = get_logger(__file__) diff --git a/src/petals/server/memory_cache.py b/src/petals/server/memory_cache.py index 0410069..ac7af41 100644 --- a/src/petals/server/memory_cache.py +++ b/src/petals/server/memory_cache.py @@ -14,10 +14,8 @@ from typing import AsyncContextManager, Dict, Optional, Union import hivemind import torch -from hivemind import use_hivemind_log_handler from hivemind.utils import TensorDescriptor, get_logger -use_hivemind_log_handler("in_root_logger") logger = get_logger(__file__) Handle = int diff --git a/src/petals/server/server.py b/src/petals/server/server.py index f509c0b..f7006cc 100644 --- a/src/petals/server/server.py +++ b/src/petals/server/server.py @@ -16,7 +16,7 @@ from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescripto from hivemind.moe.server.layers import add_custom_models_from_file from hivemind.moe.server.runtime import Runtime from hivemind.proto.runtime_pb2 import CompressionType -from hivemind.utils.logging import get_logger, use_hivemind_log_handler +from hivemind.utils.logging import get_logger from transformers import BloomConfig from petals.bloom.from_pretrained import DTYPE_MAP, load_pretrained_block @@ -32,7 +32,6 @@ from petals.server.throughput import get_host_throughput from petals.utils.convert_8bit import replace_8bit_linear from petals.utils.disk_cache import DEFAULT_CACHE_DIR -use_hivemind_log_handler("in_root_logger") logger = get_logger(__file__) diff --git a/src/petals/server/task_pool.py b/src/petals/server/task_pool.py index 4589734..1374f94 100644 --- a/src/petals/server/task_pool.py +++ b/src/petals/server/task_pool.py @@ -8,11 +8,10 @@ from queue import PriorityQueue from typing import Any, List, Optional, Sequence, Tuple import torch -from hivemind import get_logger, use_hivemind_log_handler +from hivemind import get_logger from hivemind.moe.server.task_pool import TaskPoolBase from hivemind.utils.mpfuture import ALL_STATES, MPFuture -use_hivemind_log_handler("in_root_logger") logger = get_logger(__file__) diff --git a/src/petals/server/throughput.py b/src/petals/server/throughput.py index 2bcd340..b408d70 100644 --- a/src/petals/server/throughput.py +++ b/src/petals/server/throughput.py @@ -8,7 +8,7 @@ from pathlib import Path from typing import Optional, Union import torch -from hivemind.utils.logging import get_logger, use_hivemind_log_handler +from hivemind.utils.logging import get_logger from transformers import BloomConfig from petals.bloom.block import WrappedBloomBlock @@ -16,7 +16,6 @@ from petals.server.block_utils import resolve_block_dtype from petals.utils.convert_8bit import replace_8bit_linear from petals.utils.disk_cache import DEFAULT_CACHE_DIR -use_hivemind_log_handler("in_root_logger") logger = get_logger(__file__) diff --git a/src/petals/utils/logging.py b/src/petals/utils/logging.py new file mode 100644 index 0000000..e4732f2 --- /dev/null +++ b/src/petals/utils/logging.py @@ -0,0 +1,34 @@ +import importlib +import os + +from hivemind.utils import logging as hm_logging + + +def in_jupyter() -> bool: + """Check if the code is run in Jupyter or Colab""" + + try: + __IPYTHON__ + return True + except NameError: + return False + + +def initialize_logs(): + """Initialize Petals logging tweaks. This function is called when you import the `petals` module.""" + + # Env var PETALS_LOGGING=False prohibits Petals do anything with logs + if os.getenv("PETALS_LOGGING", "True").lower() in ("false", "0"): + return + + if in_jupyter(): + os.environ["HIVEMIND_COLORS"] = "True" + importlib.reload(hm_logging) + + hm_logging.get_logger().handlers.clear() # Remove extra default handlers on Colab + hm_logging.use_hivemind_log_handler("in_root_logger") + + # We suppress asyncio error logs by default since they are mostly not relevant for the end user, + # unless there is env var PETALS_ASYNCIO_LOGLEVEL + asyncio_loglevel = os.getenv("PETALS_ASYNCIO_LOGLEVEL", "FATAL" if hm_logging.loglevel != "DEBUG" else "DEBUG") + hm_logging.get_logger("asyncio").setLevel(asyncio_loglevel) diff --git a/tests/conftest.py b/tests/conftest.py index 57287c3..de0b0de 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,10 +5,9 @@ from contextlib import suppress import psutil import pytest from hivemind.utils.crypto import RSAPrivateKey -from hivemind.utils.logging import get_logger, use_hivemind_log_handler +from hivemind.utils.logging import get_logger from hivemind.utils.mpfuture import MPFuture -use_hivemind_log_handler("in_root_logger") logger = get_logger(__name__) diff --git a/tests/test_full_model.py b/tests/test_full_model.py index e3c4730..d2b272f 100644 --- a/tests/test_full_model.py +++ b/tests/test_full_model.py @@ -1,14 +1,13 @@ import pytest import torch import transformers -from hivemind import get_logger, use_hivemind_log_handler +from hivemind import get_logger from test_utils import * from transformers.generation import BeamSearchScorer from transformers.models.bloom import BloomForCausalLM from petals.client.remote_model import DistributedBloomForCausalLM -use_hivemind_log_handler("in_root_logger") logger = get_logger(__file__) diff --git a/tests/test_remote_sequential.py b/tests/test_remote_sequential.py index 6b23a00..ed76696 100644 --- a/tests/test_remote_sequential.py +++ b/tests/test_remote_sequential.py @@ -1,6 +1,6 @@ import pytest import torch -from hivemind import DHT, BatchTensorDescriptor, get_logger, use_hivemind_log_handler +from hivemind import DHT, BatchTensorDescriptor, get_logger from hivemind.proto import runtime_pb2 from test_utils import * @@ -9,7 +9,6 @@ from petals.client import RemoteSequenceManager, RemoteSequential from petals.client.remote_model import DistributedBloomConfig from petals.data_structures import UID_DELIMITER -use_hivemind_log_handler("in_root_logger") logger = get_logger(__file__) diff --git a/tests/test_sequence_manager.py b/tests/test_sequence_manager.py index 571c35f..69d05c4 100644 --- a/tests/test_sequence_manager.py +++ b/tests/test_sequence_manager.py @@ -3,14 +3,13 @@ import time import pytest import torch -from hivemind import DHT, get_logger, use_hivemind_log_handler +from hivemind import DHT, get_logger from test_utils import * from petals.client import RemoteSequenceManager, RemoteSequential from petals.client.remote_model import DistributedBloomConfig from petals.data_structures import UID_DELIMITER -use_hivemind_log_handler("in_root_logger") logger = get_logger(__file__)