|
|
@ -31,6 +31,9 @@ class DistributedFalconConfig(DefaultRevisionMixin, FalconConfig, ClientConfig,
|
|
|
|
def from_pretrained(
|
|
|
|
def from_pretrained(
|
|
|
|
cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs
|
|
|
|
cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs
|
|
|
|
):
|
|
|
|
):
|
|
|
|
|
|
|
|
if "180B" in model_name_or_path.upper():
|
|
|
|
|
|
|
|
logger.info("Make sure you follow the Falcon-180B license: https://bit.ly/falcon-180b-license")
|
|
|
|
|
|
|
|
|
|
|
|
loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path)
|
|
|
|
loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path)
|
|
|
|
if loading_from_repo and dht_prefix is None:
|
|
|
|
if loading_from_repo and dht_prefix is None:
|
|
|
|
dht_prefix = str(model_name_or_path)
|
|
|
|
dht_prefix = str(model_name_or_path)
|
|
|
|