Fix class names

pull/500/head
Max Ryabinin 9 months ago
parent ca4d091a3f
commit 1f006c59a1

@@ -298,7 +298,7 @@ class OptimizedFalconDecoderLayer(FalconDecoderLayer):
         return outputs  # hidden_states, present, attentions
 
 
-class _WrappedFalconBlock(OptimizedFalconDecoderLayer):
+class WrappedFalconBlock(OptimizedFalconDecoderLayer):
     def __init__(self, config: FalconConfig):
         super().__init__(config)
         assert not self.config.alibi
@@ -379,7 +379,7 @@ class _WrappedFalconBlock(OptimizedFalconDecoderLayer):
         return state
 
 
-class WrappedFalconBlock(FalconDecoderLayer):
+class UnoptimizedWrappedFalconBlock(FalconDecoderLayer):
     def forward(
         self,
         hidden_states: torch.Tensor,
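Aside (not part of the commit): after this rename, the optimized implementation owns the public name WrappedFalconBlock, while the plain FalconDecoderLayer-based reference carries an explicit Unoptimized prefix instead of a leading underscore. A minimal construction sketch based only on the signatures visible in this diff; the checkpoint name and the AutoDistributedConfig.from_pretrained call are assumptions for illustration:

from petals.models.falcon.block import UnoptimizedWrappedFalconBlock, WrappedFalconBlock
from petals.utils.auto_config import AutoDistributedConfig

# Assumed checkpoint for illustration; __init__ above asserts `not config.alibi`,
# so the optimized block only accepts configs with alibi disabled.
config = AutoDistributedConfig.from_pretrained("tiiuae/falcon-rw-1b")
config.alibi = False
config.new_decoder_architecture = True

block = WrappedFalconBlock(config)                 # optimized path (OptimizedFalconDecoderLayer)
reference = UnoptimizedWrappedFalconBlock(config)  # plain FalconDecoderLayer reference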

@@ -1,4 +1,4 @@
-from petals.models.falcon.block import WrappedFalconBlock
+from petals.models.falcon.block import UnoptimizedWrappedFalconBlock
 from petals.server.from_pretrained import load_pretrained_block
 from petals.utils.auto_config import AutoDistributedConfig
 from petals.server.block_utils import resolve_block_dtype
@@ -11,7 +11,7 @@ def test_falcon():
     config.alibi = False
     config.new_decoder_architecture = True
 
-    device = "cuda:0"
+    device = "cpu"
     tensor_parallel_devices = (device,)
     dtype = torch.bfloat16
     quant_type = QuantType.NONE
@@ -19,7 +19,7 @@ def test_falcon():
     block = config.block_class(config).to(dtype)
     block = convert_block(block, 0, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True)
 
-    unopt_block = WrappedFalconBlock(config).to(dtype)
+    unopt_block = UnoptimizedWrappedFalconBlock(config).to(dtype)
     unopt_block = convert_block(
         unopt_block, 0, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True
     )
@@ -27,7 +27,7 @@ def test_falcon():
     unopt_block.load_state_dict(block.state_dict())
 
     for _ in range(3):
-        dummy_input = torch.randn(1, 1, config.hidden_size, device="cuda", dtype=dtype)
+        dummy_input = torch.randn(1, 1, config.hidden_size, device=device, dtype=dtype)
         block_output = block(dummy_input)
         unopt_block_output = unopt_block(dummy_input)
         assert torch.allclose(block_output[0], unopt_block_output[0], atol=1e-6, rtol=0)
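The loop above is a standard numerical-equivalence check: load one implementation's state dict into the other, then compare outputs on random inputs under a tight tolerance. A self-contained sketch of the same pattern with a plain torch module (the Linear layer and sizes are illustrative stand-ins, not part of this commit):

import torch

torch.manual_seed(0)
ref = torch.nn.Linear(64, 64)   # stands in for the unoptimized reference block
fast = torch.nn.Linear(64, 64)  # stands in for the optimized block
fast.load_state_dict(ref.state_dict())  # identical weights, possibly different kernels

for _ in range(3):
    x = torch.randn(1, 1, 64)
    assert torch.allclose(ref(x), fast(x), atol=1e-6, rtol=0)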
