diff --git a/src/petals/models/llama/config.py b/src/petals/models/llama/config.py index dd5f6b1..78443eb 100644 --- a/src/petals/models/llama/config.py +++ b/src/petals/models/llama/config.py @@ -33,5 +33,7 @@ class DistributedLlamaConfig(LlamaConfig, SequenceManagerConfig, PTuneConfig, LM dht_prefix = str(model_name_or_path) if "/" in dht_prefix: # If present, strip repository name to merge blocks hosted by different accounts dht_prefix = dht_prefix[dht_prefix.rfind("/") + 1 :] + if not dht_prefix.endswith("-hf"): + dht_prefix += "-hf" logger.info(f"Using DHT prefix: {dht_prefix}") return super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs)