Mirror of https://github.com/bigscience-workshop/petals
Update transformers to 4.31.0 and peft to 0.4.0 (#371)
Commit c735dd7ba3 (parent 1ab35c2826)
.github/workflows/run-tests.yaml
@@ -10,7 +10,7 @@ jobs:
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        python-version: [ '3.7', '3.8', '3.9', '3.10' ]
+        python-version: [ '3.8', '3.9', '3.10' ]
       fail-fast: false
     timeout-minutes: 15
     steps:
README.md

@@ -31,7 +31,7 @@ print(tokenizer.decode(outputs[0]))  # A cat sat on a mat...
 
 ### Connect your GPU and increase Petals capacity
 
-Run these commands in an [Anaconda](https://www.anaconda.com) env (requires Linux and Python 3.7+):
+Run these commands in an [Anaconda](https://www.anaconda.com) env (requires Linux and Python 3.8+):
 
 ```bash
 conda install pytorch pytorch-cuda=11.7 -c pytorch -c nvidia
setup.cfg

@@ -15,9 +15,9 @@ classifiers =
     Intended Audience :: Science/Research
     License :: OSI Approved :: MIT License
     Programming Language :: Python :: 3
-    Programming Language :: Python :: 3.7
     Programming Language :: Python :: 3.8
     Programming Language :: Python :: 3.9
+    Programming Language :: Python :: 3.10
     Topic :: Scientific/Engineering
     Topic :: Scientific/Engineering :: Mathematics
     Topic :: Scientific/Engineering :: Artificial Intelligence
@@ -29,14 +29,14 @@ classifiers =
 package_dir =
     = src
 packages = find:
-python_requires = >=3.7
+python_requires = >=3.8
 install_requires =
     torch>=1.12
     bitsandbytes==0.40.1.post1
     accelerate>=0.16.0,<0.21.0
     huggingface-hub>=0.11.1,<1.0.0
     tokenizers>=0.13.3
-    transformers>=4.30.1,<4.31.0
+    transformers>=4.31.0,<5.0.0
     speedtest-cli==2.1.3
     pydantic>=1.10,<2.0  # 2.0 is incompatible with hivemind==1.1.8
     hivemind==1.1.8
@@ -46,7 +46,7 @@ install_requires =
     cpufeature>=0.2.0
     packaging>=20.9
     sentencepiece>=0.1.99
-    peft@git+https://github.com/huggingface/peft@5884bdbea49e5e71e2cd06ecfa484bb635063735
+    peft>=0.4.0
     safetensors>=0.3.1
     Dijkstar>=2.6.0
 
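Taken together, the new pins mean an existing environment can be brought in line before upgrading Petals itself. A minimal sketch using exactly the bounds from the requirements above:

```bash
pip install --upgrade "transformers>=4.31.0,<5.0.0" "peft>=0.4.0"
```

Pinning to a released peft (0.4.0) instead of a git commit also makes the package installable without git and resolvable by standard dependency tooling.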
src/petals/__init__.py

@@ -16,8 +16,8 @@ __version__ = "1.2.0.dev3"
 
 if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):
     assert (
-        version.parse("4.30.1") <= version.parse(transformers.__version__) < version.parse("5.0.0")
-    ), "Please install a proper transformers version: pip install transformers>=4.30.1,<5.0.0"
+        version.parse("4.31.0") <= version.parse(transformers.__version__) < version.parse("5.0.0")
+    ), "Please install a proper transformers version: pip install transformers>=4.31.0,<5.0.0"
 
 
 def _override_bfloat16_mode_default():
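This guard runs at import time, so an out-of-range transformers install fails fast with the suggested pip command. As the condition shows, setting the PETALS_IGNORE_DEPENDENCY_VERSION environment variable skips the check, e.g. when testing an unreleased transformers build; a minimal sketch:

```python
import os

# Any non-empty value disables the assert; must be set before the first `import petals`.
os.environ["PETALS_IGNORE_DEPENDENCY_VERSION"] = "1"

import petals  # imports without checking transformers.__version__
```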
src/petals/cli/run_server.py

@@ -132,7 +132,7 @@ def main():
     parser.add_argument("--mean_balance_check_period", type=float, default=60,
                         help="Check the swarm's balance every N seconds (and rebalance it if necessary)")
 
-    parser.add_argument("--use_auth_token", action='store_true', help="auth token for from_pretrained")
+    parser.add_argument("--token", action='store_true', help="Hugging Face hub auth token for .from_pretrained()")
     parser.add_argument('--quant_type', type=str, default=None, choices=[choice.name.lower() for choice in QuantType],
                         help="Quantize blocks to 8-bit (int8 from the LLM.int8() paper) or "
                              "4-bit (nf4 from the QLoRA paper) formats to save GPU memory. "
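Note that the renamed `--token` flag keeps `action='store_true'`, so on the command line it stays a boolean switch; the actual token is then resolved by the Hugging Face libraries (for example, from a prior `huggingface-cli login`). A hypothetical invocation with an illustrative model repo:

```bash
# hypothetical invocation; the model name is illustrative
python -m petals.cli.run_server bigscience/bloom-560m --token
```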
src/petals/models/bloom/model.py

@@ -20,9 +20,7 @@ logger = get_logger(__name__)
 class DistributedBloomModel(FromPretrainedMixin, PTuneMixin, BloomModel):
     """BloomModel, but all transformer layers are hosted by the swarm"""
 
-    _keys_to_ignore_on_load_missing = (
-        BloomModel._keys_to_ignore_on_load_missing + PTuneMixin._keys_to_ignore_on_load_missing
-    )
+    _keys_to_ignore_on_load_missing = PTuneMixin._keys_to_ignore_on_load_missing
     _keys_to_ignore_on_load_unexpected = [r"^h\."]
 
     config_class = DistributedBloomConfig
@@ -93,11 +91,8 @@ class DistributedBloomModel(FromPretrainedMixin, PTuneMixin, BloomModel):
 
 
 class DistributedBloomForCausalLM(FromPretrainedMixin, RemoteGenerationMixin, BloomForCausalLM):
-    _keys_to_ignore_on_load_missing = (
-        BloomForCausalLM._keys_to_ignore_on_load_missing
-        + DistributedBloomModel._keys_to_ignore_on_load_missing
-        + [r"^lm_head\."]  # Missing since they are shared with input embeddings
-    )
+    _keys_to_ignore_on_load_missing = DistributedBloomModel._keys_to_ignore_on_load_missing
+    _keys_to_ignore_on_load_missing += [r"^lm_head\."]  # Missing since they are shared with input embeddings
     _keys_to_ignore_on_load_unexpected = DistributedBloomModel._keys_to_ignore_on_load_unexpected
 
     config_class = DistributedBloomConfig
@@ -115,10 +110,7 @@ class DistributedBloomForCausalLM(FromPretrainedMixin, RemoteGenerationMixin, BloomForCausalLM):
 
 
 class DistributedBloomForSequenceClassification(FromPretrainedMixin, BloomForSequenceClassification):
-    _keys_to_ignore_on_load_missing = (
-        BloomForSequenceClassification._keys_to_ignore_on_load_missing
-        + DistributedBloomModel._keys_to_ignore_on_load_missing
-    )
+    _keys_to_ignore_on_load_missing = DistributedBloomModel._keys_to_ignore_on_load_missing
     _keys_to_ignore_on_load_unexpected = DistributedBloomModel._keys_to_ignore_on_load_unexpected
 
     config_class = DistributedBloomConfig
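The `_keys_to_ignore_on_load_*` attributes hold regex patterns that transformers matches against state-dict key names to suppress spurious load warnings. A self-contained illustration of that matching (not the transformers implementation itself):

```python
import re

# Same pattern style as the class attributes above.
patterns = [r"^lm_head\."]
missing_keys = ["lm_head.weight", "word_embeddings.weight"]

# Keys matching any pattern are dropped from the "missing keys" warning.
reported = [k for k in missing_keys if not any(re.search(p, k) for p in patterns)]
print(reported)  # ['word_embeddings.weight']
```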
src/petals/models/llama/model.py

@@ -21,7 +21,7 @@ class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel):
     """LlamaModel, but all transformer layers are hosted by the swarm"""
 
     _keys_to_ignore_on_load_missing = PTuneMixin._keys_to_ignore_on_load_missing
-    _keys_to_ignore_on_load_unexpected = LlamaModel._keys_to_ignore_on_load_unexpected + [r"^model\.layers\."]
+    _keys_to_ignore_on_load_unexpected = [r"^model\.layers\."]
 
     config_class = DistributedLlamaConfig
 
@@ -115,6 +115,8 @@ class DistributedLlamaForCausalLM(FromPretrainedMixin, RemoteGenerationMixin, LlamaForCausalLM):
     def __init__(self, config: DistributedLlamaConfig):
         LlamaPreTrainedModel.__init__(self, config)
         self.model = DistributedLlamaModel(config)
+        self.pretraining_tp = config.pretraining_tp
+        self.vocab_size = config.vocab_size
         self.lm_head = LMHead(config)
 
         # Initialize weights and apply final processing
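transformers 4.31 has `LlamaForCausalLM.__init__` read `pretraining_tp` and `vocab_size` from the config, so the distributed subclass mirrors those assignments to stay compatible. A quick check of the config defaults (assuming transformers>=4.31.0):

```python
from transformers import LlamaConfig

config = LlamaConfig()  # library defaults
print(config.pretraining_tp, config.vocab_size)  # expected: 1 32000
```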
@@ -129,10 +131,7 @@ class DistributedLlamaForCausalLM(FromPretrainedMixin, RemoteGenerationMixin, LlamaForCausalLM):
 
 
 class DistributedLlamaForSequenceClassification(FromPretrainedMixin, LlamaForSequenceClassification):
-    _keys_to_ignore_on_load_missing = (
-        LlamaForSequenceClassification._keys_to_ignore_on_load_missing
-        + DistributedLlamaModel._keys_to_ignore_on_load_missing
-    )
+    _keys_to_ignore_on_load_missing = DistributedLlamaModel._keys_to_ignore_on_load_missing
     _keys_to_ignore_on_load_unexpected = DistributedLlamaModel._keys_to_ignore_on_load_unexpected
 
     config_class = DistributedLlamaConfig
src/petals/server/from_pretrained.py

@@ -34,12 +34,12 @@ def load_pretrained_block(
     config: Optional[PretrainedConfig] = None,
     torch_dtype: Union[torch.dtype, str] = "auto",
     revision: Optional[str] = None,
-    use_auth_token: Optional[str] = None,
+    token: Optional[str] = None,
     cache_dir: Optional[str] = None,
     max_disk_space: Optional[int] = None,
 ) -> nn.Module:
     if config is None:
-        config = AutoDistributedConfig.from_pretrained(model_name, use_auth_token=use_auth_token)
+        config = AutoDistributedConfig.from_pretrained(model_name, token=token)
     if cache_dir is None:
         cache_dir = DEFAULT_CACHE_DIR
 
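After the rename, callers pass `token` end to end instead of `use_auth_token`. A minimal sketch of loading a single block under these assumptions: the module path matches this file, and the block-number parameter (not visible in this hunk) is named `block_index`:

```python
from petals.server.from_pretrained import load_pretrained_block

# Illustrative repo; pass token="hf_..." for gated or private models.
block = load_pretrained_block("bigscience/bloom-560m", block_index=3, token=None)
print(sum(p.numel() for p in block.parameters()))
```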
@@ -54,7 +54,7 @@ def load_pretrained_block(
         model_name,
         block_prefix,
         revision=revision,
-        use_auth_token=use_auth_token,
+        token=token,
         cache_dir=cache_dir,
         max_disk_space=max_disk_space,
     )
@@ -82,12 +82,12 @@ def _load_state_dict_from_repo(
     block_prefix: str,
     *,
     revision: Optional[str] = None,
-    use_auth_token: Optional[str] = None,
+    token: Optional[str] = None,
     cache_dir: str,
     max_disk_space: Optional[int] = None,
 ) -> StateDict:
     index_file = get_file_from_repo(
-        model_name, filename="pytorch_model.bin.index.json", use_auth_token=use_auth_token, cache_dir=cache_dir
+        model_name, filename="pytorch_model.bin.index.json", use_auth_token=token, cache_dir=cache_dir
     )
     if index_file is not None:  # Sharded model
         with open(index_file) as f:
@@ -107,7 +107,7 @@ def _load_state_dict_from_repo(
             model_name,
             filename,
             revision=revision,
-            use_auth_token=use_auth_token,
+            token=token,
             cache_dir=cache_dir,
             max_disk_space=max_disk_space,
         )
@@ -125,7 +125,7 @@ def _load_state_dict_from_file(
     filename: str,
     *,
     revision: Optional[str] = None,
-    use_auth_token: Optional[str] = None,
+    token: Optional[str] = None,
     cache_dir: str,
     max_disk_space: Optional[int] = None,
     delay: float = 30,
@@ -137,7 +137,7 @@ def _load_state_dict_from_file(
             model_name,
             filename,
             revision=revision,
-            use_auth_token=use_auth_token,
+            use_auth_token=token,
             cache_dir=cache_dir,
             local_files_only=True,
         )
@@ -151,7 +151,7 @@ def _load_state_dict_from_file(
     try:
         with allow_cache_writes(cache_dir):
             url = hf_hub_url(model_name, filename, revision=revision)
-            file_size = get_hf_file_metadata(url, token=use_auth_token).size
+            file_size = get_hf_file_metadata(url, token=token).size
             if file_size is not None:
                 free_disk_space_for(file_size, cache_dir=cache_dir, max_disk_space=max_disk_space)
             else:
@@ -161,7 +161,7 @@ def _load_state_dict_from_file(
             model_name,
             filename,
             revision=revision,
-            use_auth_token=use_auth_token,
+            use_auth_token=token,
             cache_dir=cache_dir,
             local_files_only=False,
         )
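Within `_load_state_dict_from_file`, the Petals-facing parameter is now `token`, but `get_file_from_repo` still expects `use_auth_token` at this transformers version, hence the `use_auth_token=token` bridge at the call sites above. A self-contained sketch of the same pattern, with a stub standing in for the helper:

```python
from typing import Optional

def get_file_from_repo(model_name: str, filename: str, *, use_auth_token: Optional[str] = None):
    """Stub standing in for the transformers helper; returns a fake path."""
    return f"{model_name}/{filename} (auth={use_auth_token is not None})"

def fetch_file(model_name: str, filename: str, *, token: Optional[str] = None):
    # New-style `token` parameter bridged to the helper's old-style kwarg.
    return get_file_from_repo(model_name, filename, use_auth_token=token)

print(fetch_file("bigscience/bloom-560m", "config.json", token=None))
```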
src/petals/server/server.py

@@ -77,7 +77,7 @@ class Server:
         balance_quality: float = 0.75,
         mean_balance_check_period: float = 120,
         mean_block_selection_delay: float = 2.5,
-        use_auth_token: Optional[str] = None,
+        token: Optional[str] = None,
         quant_type: Optional[QuantType] = None,
         tensor_parallel_devices: Optional[Sequence[torch.device]] = None,
         skip_reachability_check: bool = False,
@@ -98,14 +98,14 @@ class Server:
         self.compression = compression
         self.stats_report_interval, self.update_period = stats_report_interval, update_period
         self.prefetch_batches, self.sender_threads = prefetch_batches, sender_threads
-        self.revision, self.use_auth_token = revision, use_auth_token
+        self.revision, self.token = revision, token
 
         if custom_module_path is not None:
             add_custom_models_from_file(custom_module_path)
 
         self.block_config = AutoDistributedConfig.from_pretrained(
             converted_model_name_or_path,
-            use_auth_token=use_auth_token,
+            token=token,
             revision=revision,
         )
 
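On the server side, the token now threads from the constructor into `AutoDistributedConfig.from_pretrained`. A minimal sketch of that call, assuming the import path (not shown in this diff) and an illustrative model repo:

```python
from petals.utils.auto_config import AutoDistributedConfig  # assumed import path

config = AutoDistributedConfig.from_pretrained(
    "bigscience/bloom-560m",  # illustrative repo
    token=None,  # or an "hf_..." access token for gated/private repos
)
print(type(config).__name__)
```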
@@ -271,7 +271,7 @@ class Server:
             self.block_config,
             self.torch_dtype,
             self.adapters,
-            use_auth_token=self.use_auth_token,
+            token=self.token,
             cache_dir=self.cache_dir,
             max_disk_space=self.max_disk_space,
         )
@@ -316,7 +316,7 @@ class Server:
             prefetch_batches=self.prefetch_batches,
             sender_threads=self.sender_threads,
             revision=self.revision,
-            use_auth_token=self.use_auth_token,
+            token=self.token,
             quant_type=self.quant_type,
             tensor_parallel_devices=self.tensor_parallel_devices,
             should_validate_reachability=self.should_validate_reachability,
@@ -409,7 +409,7 @@ class ModuleContainer(threading.Thread):
         update_period: float,
         expiration: Optional[float],
         revision: Optional[str],
-        use_auth_token: Optional[str],
+        token: Optional[str],
         quant_type: QuantType,
         tensor_parallel_devices: Sequence[torch.device],
         should_validate_reachability: bool,
@@ -443,7 +443,7 @@ class ModuleContainer(threading.Thread):
             config=block_config,
             torch_dtype=torch_dtype,
             revision=revision,
-            use_auth_token=use_auth_token,
+            token=token,
             cache_dir=cache_dir,
             max_disk_space=max_disk_space,
         )
@@ -456,7 +456,7 @@ class ModuleContainer(threading.Thread):
             quant_type,
             adapters=server_info.adapters,
             freeze=True,
-            use_auth_token=use_auth_token,
+            token=token,
             cache_dir=cache_dir,
             max_disk_space=max_disk_space,
         )
src/petals/utils/peft.py

@@ -45,13 +45,20 @@ def load_specific_module(block_idx: int, filepath: str, framework: str = "pt", device: Optional[int] = None):
     return tensors
 
 
-def get_adapter_from_repo(repo_id: str, block_idx: Optional[int] = None, device: Optional[int] = None, **kwargs):
-    config_path = get_file_from_repo(repo_id, CONFIG_NAME, **kwargs)
+def get_adapter_from_repo(
+    repo_id: str,
+    block_idx: Optional[int] = None,
+    device: Optional[int] = None,
+    *,
+    token: Optional[str] = None,
+    **kwargs,
+):
+    config_path = get_file_from_repo(repo_id, CONFIG_NAME, use_auth_token=token, **kwargs)
     if config_path is None:
         raise RuntimeError(f"File {CONFIG_NAME} does not exist in repo {repo_id}")
     config = PeftConfig.from_json_file(config_path)
 
-    weight_path = get_file_from_repo(repo_id, SAFETENSORS_WEIGHTS_NAME, **kwargs)
+    weight_path = get_file_from_repo(repo_id, SAFETENSORS_WEIGHTS_NAME, use_auth_token=token, **kwargs)
     if weight_path is None:
         raise RuntimeError(f"File {SAFETENSORS_WEIGHTS_NAME} does not exist in repo {repo_id}")
     if block_idx is None:
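The rewritten signature places `token` after a bare `*`, making it keyword-only so a positional argument can no longer be silently bound to `device`. A self-contained illustration with a stub body:

```python
from typing import Optional

def get_adapter_from_repo(
    repo_id: str,
    block_idx: Optional[int] = None,
    device: Optional[int] = None,
    *,
    token: Optional[str] = None,
    **kwargs,
):
    # Stub body; the real function fetches the adapter config and weights.
    return repo_id, block_idx, device, token

print(get_adapter_from_repo("user/lora-adapter", 3, token="hf_example"))
# get_adapter_from_repo("user/lora-adapter", 3, None, "hf_example")  # TypeError: token is keyword-only
```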
@@ -65,7 +72,7 @@ def load_peft(
     device: Optional[int] = None,
     *,
     revision: Optional[str] = None,
-    use_auth_token: Optional[str] = None,
+    token: Optional[str] = None,
     cache_dir: str,
     max_disk_space: Optional[int] = None,
     delay: float = 30,
@@ -82,7 +89,7 @@ def load_peft(
             block_idx,
             device,
             revision=revision,
-            use_auth_token=use_auth_token,
+            token=token,
             cache_dir=cache_dir,
             local_files_only=False,
         )
@@ -93,9 +100,9 @@ def load_peft(
     try:
         with allow_cache_writes(cache_dir):
             config_url = hf_hub_url(repo_id, CONFIG_NAME, revision=revision)
-            config_file_size = get_hf_file_metadata(config_url, token=use_auth_token).size
+            config_file_size = get_hf_file_metadata(config_url, token=token).size
             weight_url = hf_hub_url(repo_id, SAFETENSORS_WEIGHTS_NAME, revision=revision)
-            weight_file_size = get_hf_file_metadata(weight_url, token=use_auth_token).size
+            weight_file_size = get_hf_file_metadata(weight_url, token=token).size
 
             file_size = config_file_size + weight_file_size
             if file_size is not None:
@@ -108,7 +115,7 @@ def load_peft(
             block_idx,
             device,
             revision=revision,
-            use_auth_token=use_auth_token,
+            token=token,
             cache_dir=cache_dir,
             local_files_only=False,
         )