|
|
|
@ -1,20 +1,18 @@
|
|
|
|
|
# This code is in active development; interfaces may change.
|
|
|
|
|
import os
|
|
|
|
|
from typing import Optional, Union, Tuple
|
|
|
|
|
from typing import Optional, Tuple, Union
|
|
|
|
|
|
|
|
|
|
import hivemind
|
|
|
|
|
import torch
|
|
|
|
|
from hivemind import DHT, get_logger, use_hivemind_log_handler
|
|
|
|
|
from torch.nn import CrossEntropyLoss
|
|
|
|
|
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
|
|
|
|
|
|
|
|
|
|
from src.bloom import BloomForYou, DistributedBloomConfig
|
|
|
|
|
from src.bloom.from_pretrained import CLIENT_BRANCH, _load_state_dict
|
|
|
|
|
from src.client.remote_sequential import RemoteSequential
|
|
|
|
|
from src.data_structures import UID_DELIMITER
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
from hivemind import use_hivemind_log_handler
|
|
|
|
|
from torch.nn import CrossEntropyLoss
|
|
|
|
|
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
|
|
|
|
|
|
|
|
|
|
# Configure hivemind's logging at import time by selecting the "in_root_logger"
# handler mode. NOTE(review): presumably this attaches hivemind's handler to the
# root logger so that all loggers share one format; confirm against hivemind docs.
use_hivemind_log_handler("in_root_logger")
|
|
|
|
|
# Module-level logger, created by passing this file's path (__file__) to
# hivemind's get_logger helper.
logger = get_logger(__file__)
|
|
|
|
|
|
|
|
|
|