Mirror of https://github.com/brycedrennan/imaginAIry (synced 2024-10-31 03:20:40 +00:00)
feature: better torch installation experience
Commit 24f4af3482 (parent 71d4992dca)
.github/workflows/ci.yaml (vendored, 55 changed lines)
@@ -7,11 +7,6 @@ on:
   workflow_dispatch:
 env:
   PIP_DISABLE_PIP_VERSION_CHECK: 1
-  CACHE_PATHS: |
-    ~/.cache/huggingface
-    ~/.cache/clip
-    ~/.cache/imaginairy
-    ~/.cache/torch

 jobs:
   lint:
@@ -54,61 +49,27 @@ jobs:
         run: |
           black --diff --fast .
   test:
-    runs-on: [self-hosted, cuda]
+    runs-on: ${{ matrix.os }}
+    strategy:
+      fail-fast: false
+      matrix:
+        python-version: ["3.10"]
+        # subset: ["1/10", "2/10", "3/10", "4/10", "5/10", "6/10", "7/10", "8/10", "9/10", "10/10"]
+        os: ["nvidia-4090"]
     steps:
       - uses: actions/checkout@v3
       - name: Set up Python ${{ matrix.python-version }}
         uses: actions/setup-python@v4
         with:
           python-version: ${{ matrix.python-version }}
-          cache: pip
-          cache-dependency-path: requirements-dev.txt
+          # cache: pip
+          # cache-dependency-path: requirements-dev.txt
       - name: Install dependencies
         run: |
-          python -m pip install -r requirements-dev.txt .
-      # - name: Get current date
-      #   id: date
-      #   run: echo "::set-output name=curmonth::$(date +'%Y-%m')"
-      # - name: Cache Model Files
-      #   id: cache-model-files
-      #   uses: actions/cache/restore@v3
-      #   with:
-      #     path: ${{ env.CACHE_PATHS }}
-      #     key: ${{ steps.date.outputs.curmonth }}-b
-      # Generate initial file list for all directories
-      # - name: Generate initial model file list
-      #   run: |
-      #     for dir in $CACHE_PATHS; do
-      #       if [ -d "$dir" ]; then
-      #         find $dir
-      #       fi
-      #     done > initial_file_list.txt
+          python -m pip uninstall torch torchvision xformers triton imaginairy -y
+          python -m pip install -r requirements-dev.in . --upgrade
       - name: Test with pytest
         timeout-minutes: 30
         env:
           CUDA_LAUNCH_BLOCKING: 1
         run: |
           pytest --durations=10 -v
-      # Generate final file list and check for new files
-      # - name: Generate final model file list
-      #   run: |
-      #     for dir in CACHE_PATHS; do
-      #       if [ -d "$dir" ]; then
-      #         find $dir
-      #       fi
-      #     done > final_file_list.txt
-      #     if ! diff initial_file_list.txt final_file_list.txt > /dev/null; then
-      #       echo "New files detected."
-      #       echo "new_files=true" >> $GITHUB_ENV
-      #     else
-      #       echo "No new files detected."
-      #     fi
-      # - uses: actions/cache/save@v3
-      #   id: cache
-      #   if: env.new_files == 'true'
-      #   with:
-      #     path: ${{ env.CACHE_PATHS }}
-      #     key: ${{ steps.date.outputs.curmonth }}-b
@@ -82,6 +82,7 @@ Options:
   cutting edge features (SDXL, image prompts, etc) which will be added to imaginairy soon.
 - [self-attention guidance](https://github.com/SusungHong/Self-Attention-Guidance) which makes details of images more accurate
 - feature: added `--size` parameter for more intuitive sizing (e.g. 512, 256x256, 4k, uhd, FHD, VGA, etc)
+- feature: detect if wrong torch version is installed and provide instructions on how to install proper version
 - feature: better logging output: color, error handling
 - feature: support for pytorch 2.0
 - deprecated: support for python 3.8, 3.9
@@ -172,7 +172,10 @@ def imagine(
         num_prompts = "?"

     if get_device() == "cpu":
-        logger.info("Running in CPU mode. it's gonna be slooooooow.")
+        logger.warning("Running in CPU mode. It's gonna be slooooooow.")
+        from imaginairy.utils.torch_installer import torch_version_check
+
+        torch_version_check()

     if half_mode is None:
         half_mode = "cuda" in get_device() or get_device() == "mps"
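Note: the check runs from the CPU branch above because that is exactly the failure mode it targets: nvidia-smi sees a GPU but the installed torch build cannot use it. A minimal sketch of invoking the check manually; the torch.cuda guard here mirrors what could_install_better_torch_version() (added later in this commit) does internally:

    import torch

    from imaginairy.utils.torch_installer import torch_version_check

    if not torch.cuda.is_available():
        # Logs a warning containing pip commands for a CUDA-enabled torch
        # build, but only if nvidia-smi reports a usable CUDA version.
        torch_version_check()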
@@ -67,12 +67,19 @@ def system_info():
     """
     Display system information. Submit this when reporting bugs.
     """
-    from imaginairy.debug_info import get_debug_info
+    from imaginairy.utils.debug_info import get_debug_info

-    for k, v in get_debug_info().items():
+    debug_info = get_debug_info()
+
+    for k, v in debug_info.items():
+        if k == "nvidia_smi":
+            continue
         k += ":"
         click.secho(f"{k: <30} {v}")

+    if "nvidia_smi" in debug_info:
+        click.secho(debug_info["nvidia_smi"])
+

 @aimg.command("model-list")
 def model_list_cmd():
@@ -10,9 +10,9 @@ def run_server_cmd():
     """Run a HTTP API server."""
     import uvicorn

+    from imaginairy.cli.shared import imaginairy_click_context
     from imaginairy.http_app.app import app
-    from imaginairy.log_utils import configure_logging

-    configure_logging(level="DEBUG")
-    logger.info("Starting HTTP API server at http://0.0.0.0:8000")
-    uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")
+    with imaginairy_click_context(log_level="DEBUG"):
+        logger.info("Starting HTTP API server at http://0.0.0.0:8000")
+        uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")
@@ -10,13 +10,13 @@ logger = logging.getLogger(__name__)


 @contextmanager
-def imaginairy_click_context():
+def imaginairy_click_context(log_level="INFO"):
     from pydantic import ValidationError

     from imaginairy.log_utils import configure_logging

     errors_to_catch = (FileNotFoundError, ValidationError)
-    configure_logging()
+    configure_logging(level=log_level)
     try:
         yield
     except errors_to_catch as e:
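Note: giving imaginairy_click_context() a log_level parameter is what lets the server command above reuse the CLI's shared error handling instead of calling configure_logging() directly. A minimal usage sketch; some_cmd and do_work are hypothetical, only the context-manager signature comes from this diff:

    from imaginairy.cli.shared import imaginairy_click_context

    def some_cmd():
        with imaginairy_click_context(log_level="DEBUG"):
            # FileNotFoundError and pydantic.ValidationError raised in here
            # are caught and reported cleanly instead of dumping a traceback.
            do_work()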
@@ -1,5 +1,9 @@
+from functools import lru_cache
+
+
 def get_debug_info():
     import os.path
+    import platform
     import sys

     import psutil
@@ -15,9 +19,10 @@ def get_debug_info():
         "python_installation_path": sys.executable,
         "device": get_device(),
         "torch_version": torch.__version__,
-        "platform": sys.platform,
+        "platform": platform.system(),
         "hardware_description": get_hardware_description(get_device()),
         "ram_gb": round(psutil.virtual_memory().total / (1024**3), 2),
+        "cuda_available": torch.cuda.is_available(),
     }
     if torch.cuda.is_available():
         device_props = torch.cuda.get_device_properties(0)
@@ -36,4 +41,38 @@ def get_debug_info():
             "graphics_card": "Apple MPS",
         }
     )
+
+    nvidia_data = get_nvidia_smi_data()
+    data.update(nvidia_data)
+
     return data
+
+
+@lru_cache
+def _get_nvidia_smi_output():
+    import subprocess
+
+    try:
+        output = subprocess.check_output(
+            "nvidia-smi", text=True, shell=True, stderr=subprocess.DEVNULL
+        )
+    except subprocess.CalledProcessError:
+        output = "no nvidia card found"
+
+    return output
+
+
+def _process_nvidia_smi_output(output):
+    import re
+
+    cuda_version = re.search(r"CUDA Version: (\d+\.\d+)", output)
+
+    return {
+        "cuda_version": cuda_version.group(1) if cuda_version else None,
+        "nvidia_smi": output,
+    }
+
+
+def get_nvidia_smi_data():
+    smi_output = _get_nvidia_smi_output()
+    return _process_nvidia_smi_output(smi_output)
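Note: the parse above relies only on the "CUDA Version: NN.N" field of the nvidia-smi banner, and _get_nvidia_smi_output() is wrapped in lru_cache so the subprocess runs at most once per process. A quick sketch of the extraction; the sample banner line is illustrative, not captured from a real machine:

    import re

    # First banner line of nvidia-smi output (exact format varies by driver).
    sample = "| NVIDIA-SMI 535.104.05   Driver Version: 535.104.05   CUDA Version: 12.2 |"

    match = re.search(r"CUDA Version: (\d+\.\d+)", sample)
    print(match.group(1) if match else None)  # -> 12.2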
imaginairy/utils/torch_installer.py (new file, 138 lines)
@@ -0,0 +1,138 @@
import logging
import subprocess

from packaging.version import Version

logger = logging.getLogger(__name__)


def torch_version_check():
    if not could_install_better_torch_version():
        return

    import platform

    import torch

    from imaginairy.utils.debug_info import get_nvidia_smi_data

    nvidia_data = get_nvidia_smi_data()

    cuda_version = Version(nvidia_data["cuda_version"])
    cmd_parts = generate_torch_install_command(
        installed_cuda_version=cuda_version, system_type=platform.system().lower()
    )
    cmd_str = " ".join(cmd_parts)
    linebreak = "*" * 72
    msg = (
        f"\n{linebreak}\n"
        f"torch=={torch.__version__} is installed and unable to use CUDA {cuda_version}.\n\n"
        "You can install the correct version by running:\n\n"
        f"  pip uninstall torch torchvision -y\n"
        f"  {cmd_str}\n\n"
        "Installing the correct version will speed up image generation.\n"
        f"{linebreak}\n"
    )
    logger.warning(msg)


def could_install_better_torch_version():
    import platform

    if platform.system().lower() not in ("windows", "linux"):
        return False

    import torch

    if torch.cuda.is_available():
        return False

    from imaginairy.utils.debug_info import get_nvidia_smi_data

    nvidia_data = get_nvidia_smi_data()
    cuda_version = nvidia_data["cuda_version"]
    if cuda_version is None:
        return False

    cuda_version = Version(cuda_version)
    determine_torch_index(
        installed_cuda_version=cuda_version, system_type=platform.system()
    )

    return True


def determine_torch_index(installed_cuda_version: Version, system_type: str):
    cuda_pypi_base_url = "https://download.pytorch.org/whl/"
    min_required_cuda_version = Version("11.8")
    system_type = system_type.lower()

    if installed_cuda_version < min_required_cuda_version:
        msg = f"Your CUDA version ({installed_cuda_version}) is too old. Please upgrade to at least CUDA {min_required_cuda_version}."
        raise ValueError(msg)

    if system_type == "windows":
        if installed_cuda_version >= Version("12.1"):
            return f"{cuda_pypi_base_url}cu121"
        if installed_cuda_version >= Version("12.0"):
            raise ValueError("You should upgrade to CUDA>=12.1")
        if installed_cuda_version >= Version("11.8"):
            return f"{cuda_pypi_base_url}cu118"

    elif system_type == "linux":
        if installed_cuda_version >= Version("12.1"):
            return ""
        if installed_cuda_version >= Version("12.0"):
            raise ValueError("You should upgrade to CUDA>=12.1")
        if installed_cuda_version >= Version("11.8"):
            return f"{cuda_pypi_base_url}cu118"

    return ""


def generate_torch_install_command(installed_cuda_version: Version, system_type):
    packages = ["torch", "torchvision"]
    index_url = determine_torch_index(
        installed_cuda_version=installed_cuda_version, system_type=system_type
    )
    cmd_parts = [
        "pip",
        "install",
        "--upgrade",
    ]
    if index_url:
        cmd_parts.extend(
            [
                "--index-url",
                index_url,
            ]
        )
    cmd_parts.extend(packages)
    return cmd_parts


def install_packages(packages, index_url):
    """
    Install a list of Python packages from a specified index server.

    :param packages: A list of package names to install.
    :param index_url: The URL of the Python package index server.
    """
    import sys

    for package in packages:
        try:
            subprocess.check_call(
                [
                    sys.executable,
                    "-m",
                    "pip",
                    "install",
                    package,
                    "--index-url",
                    index_url,
                ]
            )
            print(f"Successfully installed {package}")
        except subprocess.CalledProcessError as e:
            print(f"Failed to install {package}: {e}")
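Note: determine_torch_index() doubles as validation; could_install_better_torch_version() calls it apparently for the ValueError it raises on unsupported CUDA versions and discards the return value. A sketch of the happy path with example inputs (the CUDA versions and platforms below are illustrative, not detected):

    from packaging.version import Version

    from imaginairy.utils.torch_installer import generate_torch_install_command

    # A Windows box reporting CUDA 12.2 via nvidia-smi -> the cu121 wheel index.
    cmd = generate_torch_install_command(
        installed_cuda_version=Version("12.2"), system_type="windows"
    )
    print(" ".join(cmd))
    # pip install --upgrade --index-url https://download.pytorch.org/whl/cu121 torch torchvision

    # Linux with CUDA >= 12.1 yields an empty index URL, so no --index-url
    # flag is appended and the default package index is used.
    cmd = generate_torch_install_command(
        installed_cuda_version=Version("12.1"), system_type="linux"
    )
    print(" ".join(cmd))
    # pip install --upgrade torch torchvision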
@@ -43,7 +43,7 @@ coverage==7.3.2
     # via -r requirements-dev.in
 cycler==0.12.1
     # via matplotlib
-diffusers==0.23.1
+diffusers==0.24.0
     # via imaginAIry (setup.py)
 einops==0.7.0
     # via imaginAIry (setup.py)
@@ -65,13 +65,13 @@ filelock==3.13.1
     #   transformers
 filterpy==1.4.5
     # via facexlib
-fonttools==4.45.1
+fonttools==4.46.0
     # via matplotlib
 frozenlist==1.4.0
     # via
     #   aiohttp
     #   aiosignal
-fsspec[http]==2023.10.0
+fsspec[http]==2023.12.0
     # via
     #   huggingface-hub
     #   pytorch-lightning
@@ -96,7 +96,7 @@ idna==3.6
     #   yarl
 imageio==2.33.0
     # via imaginAIry (setup.py)
-importlib-metadata==6.8.0
+importlib-metadata==7.0.0
     # via diffusers
 iniconfig==2.0.0
     # via pytest
@@ -170,6 +170,7 @@ packaging==23.2
     #   pytest
     #   pytest-sugar
     #   pytorch-lightning
+    #   torchmetrics
     #   transformers
 pathspec==0.11.2
     # via black
@@ -182,7 +183,7 @@ pillow==10.1.0
     #   matplotlib
     #   refiners
     #   torchvision
-platformdirs==4.0.0
+platformdirs==4.1.0
     # via black
 pluggy==1.3.0
     # via pytest
@@ -239,7 +240,7 @@ requests==2.31.0
     #   transformers
 responses==0.24.1
     # via -r requirements-dev.in
-ruff==0.1.6
+ruff==0.1.7
     # via -r requirements-dev.in
 safetensors==0.3.3
     # via
@@ -264,7 +265,7 @@ starlette==0.27.0
     # via fastapi
 sympy==1.12
     # via torch
-termcolor==2.3.0
+termcolor==2.4.0
     # via pytest-sugar
 timm==0.9.12
     # via
@@ -291,7 +292,7 @@ torch==2.1.1
     #   torchvision
 torchdiffeq==0.2.3
     # via imaginAIry (setup.py)
-torchmetrics==1.2.0
+torchmetrics==1.2.1
     # via
     #   imaginAIry (setup.py)
     #   pytorch-lightning
@@ -153,3 +153,18 @@ def pytest_collection_modifyitems(config, items):
     filtered_node_ids.sort()
     for n in filtered_node_ids:
         print(f"    {n}")
+
+
+def pytest_sessionstart(session):
+    from imaginairy.utils.debug_info import get_debug_info
+
+    debug_info = get_debug_info()
+
+    for k, v in debug_info.items():
+        if k == "nvidia_smi":
+            continue
+        k += ":"
+        print(f"{k: <30} {v}")
+
+    if "nvidia_smi" in debug_info:
+        print(debug_info["nvidia_smi"])
tests/test_utils/test_torch_installer.py (new file, 21 lines)
@@ -0,0 +1,21 @@
import pytest
from packaging.version import Version

from imaginairy.utils.torch_installer import determine_torch_index

index_base = "https://download.pytorch.org/whl/"
index_cu118 = f"{index_base}cu118"
index_cu121 = f"{index_base}cu121"

torch_index_fixture = [
    (Version("11.8"), "linux", index_cu118),
    (Version("12.1"), "linux", ""),
    (Version("12.2"), "linux", ""),
    (Version("12.1"), "windows", index_cu121),
    (Version("12.2"), "windows", index_cu121),
]


@pytest.mark.parametrize(("cuda_version", "platform", "expected"), torch_index_fixture)
def test_determine_torch_index(cuda_version, platform, expected):
    assert determine_torch_index(cuda_version, platform) == expected
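Note: the fixtures compare CUDA versions with packaging.version.Version rather than raw strings, which matters once minor versions reach double digits. A tiny illustration:

    from packaging.version import Version

    # Lexicographic string comparison gets multi-digit minors wrong;
    # Version compares numerically.
    print("12.10" > "12.9")                    # False
    print(Version("12.10") > Version("12.9"))  # True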