|
|
|
from typing import Optional, Tuple
|
|
|
|
|
|
|
|
|
|
import pytest
|
|
|
|
|
import torch
|
|
|
|
|
from transformers.cache_utils import DynamicCache
|
|
|
|
|
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
|
|
|
|
|
from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel, build_alibi_tensor
|
|
|
|
|
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel
|
|
|
|
|
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
|
|
|
|
|
|
|
|
|
|
from petals.utils.auto_config import AutoDistributedConfig
|
|
|
|
|
from petals.utils.convert_block import QuantType, convert_block
|
|
|
|
|
from transformers.cache_utils import DynamicCache
|
|
|
|
|
from test_utils import MODEL_NAME
|
|
|
|
|
|
|
|
|
|
# Type alias for one attention-layer cache entry: a (key, value) pair of tensors
# (presumably past key states and past value states — confirm against usage in the tests).
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
|
|
|
|