You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
45 lines
1.4 KiB
Python
45 lines
1.4 KiB
Python
import os
|
|
import re
|
|
from typing import Union
|
|
|
|
import requests
|
|
from hivemind.utils.logging import TextStyle, get_logger
|
|
from packaging.version import parse
|
|
|
|
import petals
|
|
|
|
logger = get_logger(__name__)
|
|
|
|
|
|
def validate_version() -> None:
|
|
logger.info(f"Running {TextStyle.BOLD}Petals {petals.__version__}{TextStyle.RESET}")
|
|
try:
|
|
r = requests.get("https://pypi.python.org/pypi/petals/json")
|
|
r.raise_for_status()
|
|
response = r.json()
|
|
|
|
versions = [parse(ver) for ver in response.get("releases")]
|
|
latest = max(ver for ver in versions if not ver.is_prerelease)
|
|
|
|
if parse(petals.__version__) < latest:
|
|
logger.info(
|
|
f"A newer version {latest} is available. Please upgrade with: "
|
|
f"{TextStyle.BOLD}pip install --upgrade petals{TextStyle.RESET}"
|
|
)
|
|
except Exception as e:
|
|
logger.warning("Failed to fetch the latest Petals version from PyPI:", exc_info=True)
|
|
|
|
|
|
def get_compatible_model_repo(model_name_or_path: Union[str, os.PathLike, None]) -> Union[str, os.PathLike, None]:
|
|
if model_name_or_path is None:
|
|
return None
|
|
|
|
match = re.fullmatch(r"(bigscience/.+)-petals", str(model_name_or_path))
|
|
if match is None:
|
|
return model_name_or_path
|
|
|
|
logger.info(
|
|
f"Loading model from {match.group(1)}, since Petals 1.2.0+ uses original repos instead of converted ones"
|
|
)
|
|
return match.group(1)
|