|
|
|
@ -5,13 +5,15 @@ See commit history for authorship.
|
|
|
|
|
"""
|
|
|
|
|
import os
|
|
|
|
|
from typing import Optional, Tuple
|
|
|
|
|
from packaging import version
|
|
|
|
|
|
|
|
|
|
import torch.nn.quantized.dynamic.modules.linear
|
|
|
|
|
import transformers
|
|
|
|
|
from transformers.models.bloom.modeling_bloom import BloomBlock, _expand_mask, _make_causal_mask, build_alibi_tensor
|
|
|
|
|
|
|
|
|
|
if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):
|
|
|
|
|
assert transformers.__version__.startswith("4.26."), "Please install transformers 4.26.1"
|
|
|
|
|
assert version.parse("4.26.0") < version.parse(transformers.__version__) < version.parse("5.0.0"), \
|
|
|
|
|
"Please install transformers >=4.26.0,<5.0.0"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class WrappedBloomBlock(BloomBlock):
|
|
|
|
|