Fix formatting, reduce diff

pull/500/head
Max Ryabinin, 9 months ago
parent 1fc22bd69f
commit 67764fea9e

@@ -24,7 +24,6 @@ from transformers.models.falcon.modeling_falcon import (
     rotate_half,
 )
 
-
 KVCache = Tuple[torch.Tensor, torch.Tensor]
 INFERENCE_MAX_LENGTH = 8192
@@ -225,6 +224,7 @@ class OptimizedFalconDecoderLayer(FalconDecoderLayer):
         self.hidden_dropout = config.hidden_dropout
         self.config = config
 
+        assert not self.config.alibi
         assert config.new_decoder_architecture
         self.ln_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
         self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
@@ -299,10 +299,6 @@ class OptimizedFalconDecoderLayer(FalconDecoderLayer):
 
 class WrappedFalconBlock(OptimizedFalconDecoderLayer):
-    def __init__(self, config: FalconConfig):
-        super().__init__(config)
-        assert not self.config.alibi
-
     def forward(
         self,
         hidden_states: torch.Tensor,

@@ -1,9 +1,10 @@
+import torch
+
 from petals.models.falcon.block import UnoptimizedWrappedFalconBlock
+from petals.server.block_utils import resolve_block_dtype
 from petals.server.from_pretrained import load_pretrained_block
 from petals.utils.auto_config import AutoDistributedConfig
-from petals.server.block_utils import resolve_block_dtype
 from petals.utils.convert_block import QuantType, convert_block
-import torch
 
 
 def test_falcon():
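
Not part of the diff: a minimal sketch of how the helpers imported in this test file typically compose, assuming the `load_pretrained_block`/`convert_block` signatures from this branch. The checkpoint name, the `"auto"` dtype argument, and the bare `block(hidden_states)` call are illustrative assumptions, not the actual body of `test_falcon`.

```python
import torch

from petals.server.block_utils import resolve_block_dtype
from petals.server.from_pretrained import load_pretrained_block
from petals.utils.auto_config import AutoDistributedConfig
from petals.utils.convert_block import QuantType, convert_block

# Hypothetical small checkpoint for illustration; the real test may use a different model.
MODEL_NAME = "tiiuae/falcon-rw-1b"


def sketch_falcon_block_roundtrip():
    config = AutoDistributedConfig.from_pretrained(MODEL_NAME)
    dtype = resolve_block_dtype(config, "auto")  # map "auto" to a concrete torch dtype

    # Load a single transformer block and convert it the way a Petals server would
    # (no quantization, no tensor parallelism); parameter names follow this branch.
    block = load_pretrained_block(MODEL_NAME, 0, config=config, torch_dtype=dtype)
    device = torch.device("cpu")
    block = convert_block(
        block, 0, config,
        tensor_parallel_devices=[],
        output_device=device,
        quant_type=QuantType.NONE,
        freeze=True,
    )

    # Wrapped blocks expose a uniform forward(hidden_states, ...) interface.
    hidden_states = torch.randn(1, 3, config.hidden_size, dtype=dtype)
    with torch.inference_mode():
        outputs = block(hidden_states)
    return outputs[0]
```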
