|
|
@ -13,7 +13,7 @@ from transformers.models.bloom.modeling_bloom import BloomModel
|
|
|
|
from petals.bloom.from_pretrained import BLOCK_BRANCH_PREFIX, CLIENT_BRANCH
|
|
|
|
from petals.bloom.from_pretrained import BLOCK_BRANCH_PREFIX, CLIENT_BRANCH
|
|
|
|
from petals.client import DistributedBloomConfig
|
|
|
|
from petals.client import DistributedBloomConfig
|
|
|
|
|
|
|
|
|
|
|
|
logger = get_logger(__file__)
|
|
|
|
logger = get_logger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto")
|
|
|
|
DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto")
|
|
|
|
|
|
|
|
|
|
|
|