@ -8,10 +8,13 @@ from typing import Optional, Tuple
import torch . nn . quantized . dynamic . modules . linear
import torch . nn . quantized . dynamic . modules . linear
import transformers
import transformers
from packaging import version
from transformers . models . bloom . modeling_bloom import BloomBlock , _expand_mask , _make_causal_mask , build_alibi_tensor
from transformers . models . bloom . modeling_bloom import BloomBlock , _expand_mask , _make_causal_mask , build_alibi_tensor
if not os . getenv ( " PETALS_IGNORE_DEPENDENCY_VERSION " ) :
if not os . getenv ( " PETALS_IGNORE_DEPENDENCY_VERSION " ) :
assert transformers . __version__ . startswith ( " 4.25. " ) , " Please install transformers 4.25.1 "
assert (
version . parse ( " 4.26.0 " ) < version . parse ( transformers . __version__ ) < version . parse ( " 5.0.0 " )
) , " Please install a proper transformers version: pip install transformers>=4.26.0,<5.0.0 "
class WrappedBloomBlock ( BloomBlock ) :
class WrappedBloomBlock ( BloomBlock ) :