From e297ae606f3894912ca038b269bc277c18843a08 Mon Sep 17 00:00:00 2001 From: dbaranchuk Date: Wed, 3 Aug 2022 12:47:36 +0300 Subject: [PATCH] black --- src/bloom/from_pretrained.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/bloom/from_pretrained.py b/src/bloom/from_pretrained.py index e893ea4..b8bd398 100644 --- a/src/bloom/from_pretrained.py +++ b/src/bloom/from_pretrained.py @@ -34,13 +34,15 @@ def load_pretrained_block( config: Optional[BloomConfig] = None, torch_dtype: Union[torch.dtype, str] = "auto", use_auth_token: Optional[str] = None, - cache_dir: Optional[str] = None + cache_dir: Optional[str] = None, ) -> BloomBlock: """Load one BloomBlock from a converted model. See convert_model.py (or README.md) on how to convert it.""" if config is None: config = BloomConfig.from_pretrained(converted_model_name_or_path, use_auth_token=use_auth_token) block = BloomBlock(config, layer_number=block_index) - state_dict = _load_state_dict(converted_model_name_or_path, block_index, use_auth_token=use_auth_token, cache_dir=cache_dir) + state_dict = _load_state_dict( + converted_model_name_or_path, block_index, use_auth_token=use_auth_token, cache_dir=cache_dir + ) block.load_state_dict(state_dict) if torch_dtype == "auto": @@ -58,10 +60,10 @@ def load_pretrained_block( def _load_state_dict( - pretrained_model_name_or_path: str, - block_index: Optional[int] = None, - use_auth_token: Optional[str] = None, - cache_dir: Optional[str] = None + pretrained_model_name_or_path: str, + block_index: Optional[int] = None, + use_auth_token: Optional[str] = None, + cache_dir: Optional[str] = None, ) -> OrderedDict[str, torch.Tensor]: revision = BLOCK_BRANCH_PREFIX + str(block_index) if block_index is not None else CLIENT_BRANCH archive_file = hf_bucket_url(pretrained_model_name_or_path, filename=WEIGHTS_NAME, revision=revision, mirror=None)