|
|
|
@@ -1,10 +1,11 @@
|
|
|
|
|
import os
|
|
|
|
|
from contextlib import contextmanager
|
|
|
|
|
from typing import List, Optional
|
|
|
|
|
|
|
|
|
|
import hivemind
|
|
|
|
|
import torch
|
|
|
|
|
import torch.nn as nn
|
|
|
|
|
from hivemind import get_logger, use_hivemind_log_handler
|
|
|
|
|
from hivemind.utils.logging import get_logger, loglevel, use_hivemind_log_handler
|
|
|
|
|
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
|
|
|
|
|
|
|
|
|
|
from petals.bloom.model import (
|
|
|
|
@@ -23,6 +24,10 @@ from petals.utils.misc import DUMMY
|
|
|
|
|
# Module-level logging setup: hand hivemind's records to the root logger
# (presumably so they share the application's formatting — see hivemind docs)
# and create this module's logger.
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)

# We suppress asyncio error logs by default since they are mostly not relevant for the end user
_default_asyncio_level = "DEBUG" if loglevel == "DEBUG" else "FATAL"
asyncio_loglevel = os.getenv("PETALS_ASYNCIO_LOGLEVEL", _default_asyncio_level)
get_logger("asyncio").setLevel(asyncio_loglevel)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DistributedBloomConfig(BloomConfig):
|
|
|
|
|
"""
|
|
|
|
|