feature: better torch installation experience

Bryce 2023-12-04 19:01:51 -08:00 committed by Bryce Drennan
parent 71d4992dca
commit 24f4af3482
11 changed files with 251 additions and 65 deletions

View File

@@ -7,11 +7,6 @@ on:
   workflow_dispatch:
 env:
   PIP_DISABLE_PIP_VERSION_CHECK: 1
-  CACHE_PATHS: |
-    ~/.cache/huggingface
-    ~/.cache/clip
-    ~/.cache/imaginairy
-    ~/.cache/torch
 jobs:
   lint:
@@ -54,61 +49,27 @@ jobs:
         run: |
           black --diff --fast .
   test:
-    runs-on: [self-hosted, cuda]
+    runs-on: ${{ matrix.os }}
     strategy:
       fail-fast: false
       matrix:
         python-version: ["3.10"]
-        # subset: ["1/10", "2/10", "3/10", "4/10", "5/10", "6/10", "7/10", "8/10", "9/10", "10/10"]
+        os: ["nvidia-4090"]
     steps:
       - uses: actions/checkout@v3
       - name: Set up Python ${{ matrix.python-version }}
         uses: actions/setup-python@v4
         with:
           python-version: ${{ matrix.python-version }}
-          cache: pip
-          cache-dependency-path: requirements-dev.txt
+          # cache: pip
+          # cache-dependency-path: requirements-dev.txt
      - name: Install dependencies
        run: |
-          python -m pip install -r requirements-dev.txt .
-      # - name: Get current date
-      #   id: date
-      #   run: echo "::set-output name=curmonth::$(date +'%Y-%m')"
-      # - name: Cache Model Files
-      #   id: cache-model-files
-      #   uses: actions/cache/restore@v3
-      #   with:
-      #     path: ${{ env.CACHE_PATHS }}
-      #     key: ${{ steps.date.outputs.curmonth }}-b
-      # Generate initial file list for all directories
-      # - name: Generate initial model file list
-      #   run: |
-      #     for dir in $CACHE_PATHS; do
-      #       if [ -d "$dir" ]; then
-      #         find $dir
-      #       fi
-      #     done > initial_file_list.txt
+          python -m pip uninstall torch torchvision xformers triton imaginairy -y
+          python -m pip install -r requirements-dev.in . --upgrade
      - name: Test with pytest
        timeout-minutes: 30
-        env:
-          CUDA_LAUNCH_BLOCKING: 1
        run: |
          pytest --durations=10 -v
-      # Generate final file list and check for new files
-      # - name: Generate final model file list
-      #   run: |
-      #     for dir in CACHE_PATHS; do
-      #       if [ -d "$dir" ]; then
-      #         find $dir
-      #       fi
-      #     done > final_file_list.txt
-      #     if ! diff initial_file_list.txt final_file_list.txt > /dev/null; then
-      #       echo "New files detected."
-      #       echo "new_files=true" >> $GITHUB_ENV
-      #     else
-      #       echo "No new files detected."
-      #     fi
-      # - uses: actions/cache/save@v3
-      #   id: cache
-      #   if: env.new_files == 'true'
-      #   with:
-      #     path: ${{ env.CACHE_PATHS }}
-      #     key: ${{ steps.date.outputs.curmonth }}-b

View File

@@ -82,6 +82,7 @@ Options:
     cutting edge features (SDXL, image prompts, etc) which will be added to imaginairy soon.
 - [self-attention guidance](https://github.com/SusungHong/Self-Attention-Guidance) which makes details of images more accurate
 - feature: added `--size` parameter for more intuitive sizing (e.g. 512, 256x256, 4k, uhd, FHD, VGA, etc)
+- feature: detect if wrong torch version is installed and provide instructions on how to install proper version
 - feature: better logging output: color, error handling
 - feature: support for pytorch 2.0
 - deprecated: support for python 3.8, 3.9

View File

@@ -172,7 +172,10 @@ def imagine(
         num_prompts = "?"

     if get_device() == "cpu":
-        logger.info("Running in CPU mode. it's gonna be slooooooow.")
+        logger.warning("Running in CPU mode. It's gonna be slooooooow.")
+        from imaginairy.utils.torch_installer import torch_version_check
+
+        torch_version_check()

     if half_mode is None:
         half_mode = "cuda" in get_device() or get_device() == "mps"

View File

@@ -67,12 +67,19 @@ def system_info():
     """
     Display system information. Submit this when reporting bugs.
     """
-    from imaginairy.debug_info import get_debug_info
+    from imaginairy.utils.debug_info import get_debug_info

-    for k, v in get_debug_info().items():
+    debug_info = get_debug_info()
+
+    for k, v in debug_info.items():
+        if k == "nvidia_smi":
+            continue
         k += ":"
         click.secho(f"{k: <30} {v}")

+    if "nvidia_smi" in debug_info:
+        click.secho(debug_info["nvidia_smi"])
+

 @aimg.command("model-list")
 def model_list_cmd():
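Not part of the commit, but worth noting for readers unfamiliar with the format spec used in `system_info` above: `{k: <30}` left-pads each key to a 30-character column so the values line up. A quick sketch with illustrative values:

    # the key is padded to 30 characters, so every value starts in the same column
    k = "torch_version" + ":"
    print(f"{k: <30} 2.1.1")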

View File

@@ -10,9 +10,9 @@ def run_server_cmd():
     """Run a HTTP API server."""
     import uvicorn

+    from imaginairy.cli.shared import imaginairy_click_context
     from imaginairy.http_app.app import app
-    from imaginairy.log_utils import configure_logging

-    configure_logging(level="DEBUG")
-    logger.info("Starting HTTP API server at http://0.0.0.0:8000")
-    uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")
+    with imaginairy_click_context(log_level="DEBUG"):
+        logger.info("Starting HTTP API server at http://0.0.0.0:8000")
+        uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")

View File

@@ -10,13 +10,13 @@ logger = logging.getLogger(__name__)
 @contextmanager
-def imaginairy_click_context():
+def imaginairy_click_context(log_level="INFO"):
     from pydantic import ValidationError

     from imaginairy.log_utils import configure_logging

     errors_to_catch = (FileNotFoundError, ValidationError)
-    configure_logging()
+    configure_logging(level=log_level)
     try:
         yield
     except errors_to_catch as e:

View File

@ -1,5 +1,9 @@
from functools import lru_cache
def get_debug_info(): def get_debug_info():
import os.path import os.path
import platform
import sys import sys
import psutil import psutil
@ -15,9 +19,10 @@ def get_debug_info():
"python_installation_path": sys.executable, "python_installation_path": sys.executable,
"device": get_device(), "device": get_device(),
"torch_version": torch.__version__, "torch_version": torch.__version__,
"platform": sys.platform, "platform": platform.system(),
"hardware_description": get_hardware_description(get_device()), "hardware_description": get_hardware_description(get_device()),
"ram_gb": round(psutil.virtual_memory().total / (1024**3), 2), "ram_gb": round(psutil.virtual_memory().total / (1024**3), 2),
"cuda_available": torch.cuda.is_available(),
} }
if torch.cuda.is_available(): if torch.cuda.is_available():
device_props = torch.cuda.get_device_properties(0) device_props = torch.cuda.get_device_properties(0)
@ -36,4 +41,38 @@ def get_debug_info():
"graphics_card": "Apple MPS", "graphics_card": "Apple MPS",
} }
) )
nvidia_data = get_nvidia_smi_data()
data.update(nvidia_data)
return data return data
@lru_cache
def _get_nvidia_smi_output():
import subprocess
try:
output = subprocess.check_output(
"nvidia-smi", text=True, shell=True, stderr=subprocess.DEVNULL
)
except subprocess.CalledProcessError:
output = "no nvidia card found"
return output
def _process_nvidia_smi_output(output):
import re
cuda_version = re.search(r"CUDA Version: (\d+\.\d+)", output)
return {
"cuda_version": cuda_version.group(1) if cuda_version else None,
"nvidia_smi": output,
}
def get_nvidia_smi_data():
smi_output = _get_nvidia_smi_output()
return _process_nvidia_smi_output(smi_output)
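Not part of the commit: a minimal sketch of how `_process_nvidia_smi_output` behaves, assuming a typical nvidia-smi banner line (the sample string is illustrative; real output is a multi-line table):

    from imaginairy.utils.debug_info import _process_nvidia_smi_output

    # illustrative nvidia-smi header line (not taken from this repo)
    sample = "| NVIDIA-SMI 535.129.03   Driver Version: 535.129.03   CUDA Version: 12.1 |"

    info = _process_nvidia_smi_output(sample)
    assert info["cuda_version"] == "12.1"  # regex captures the major.minor version
    assert info["nvidia_smi"] == sample    # raw output is passed through for display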

View File

@@ -0,0 +1,138 @@
+import logging
+import subprocess
+
+from packaging.version import Version
+
+logger = logging.getLogger(__name__)
+
+
+def torch_version_check():
+    if not could_install_better_torch_version():
+        return
+    import platform
+
+    import torch
+
+    from imaginairy.utils.debug_info import get_nvidia_smi_data
+
+    nvidia_data = get_nvidia_smi_data()
+    cuda_version = Version(nvidia_data["cuda_version"])
+    cmd_parts = generate_torch_install_command(
+        installed_cuda_version=cuda_version, system_type=platform.system().lower()
+    )
+    cmd_str = " ".join(cmd_parts)
+    linebreak = "*" * 72
+    msg = (
+        f"\n{linebreak}\n"
+        f"torch=={torch.__version__} is installed and unable to use CUDA {cuda_version}.\n\n"
+        "You can install the correct version by running:\n\n"
+        f"    pip uninstall torch torchvision -y\n"
+        f"    {cmd_str}\n\n"
+        "Installing the correct version will speed up image generation.\n"
+        f"{linebreak}\n"
+    )
+    logger.warning(msg)
+
+
+def could_install_better_torch_version():
+    import platform
+
+    if platform.system().lower() not in ("windows", "linux"):
+        return False
+    import torch
+
+    if torch.cuda.is_available():
+        return False
+    from imaginairy.utils.debug_info import get_nvidia_smi_data
+
+    nvidia_data = get_nvidia_smi_data()
+    cuda_version = nvidia_data["cuda_version"]
+    if cuda_version is None:
+        return False
+    cuda_version = Version(cuda_version)
+    determine_torch_index(
+        installed_cuda_version=cuda_version, system_type=platform.system()
+    )
+    return True
+
+
+def determine_torch_index(installed_cuda_version: Version, system_type: str):
+    cuda_pypi_base_url = "https://download.pytorch.org/whl/"
+    min_required_cuda_version = Version("11.8")
+    system_type = system_type.lower()
+    if installed_cuda_version < min_required_cuda_version:
+        msg = f"Your CUDA version ({installed_cuda_version}) is too old. Please upgrade to at least CUDA {min_required_cuda_version}."
+        raise ValueError(msg)
+
+    if system_type == "windows":
+        if installed_cuda_version >= Version("12.1"):
+            return f"{cuda_pypi_base_url}cu121"
+        if installed_cuda_version >= Version("12.0"):
+            raise ValueError("You should upgrade to CUDA>=12.1")
+        if installed_cuda_version >= Version("11.8"):
+            return f"{cuda_pypi_base_url}cu118"
+    elif system_type == "linux":
+        if installed_cuda_version >= Version("12.1"):
+            return ""
+        if installed_cuda_version >= Version("12.0"):
+            raise ValueError("You should upgrade to CUDA>=12.1")
+        if installed_cuda_version >= Version("11.8"):
+            return f"{cuda_pypi_base_url}cu118"
+    return ""
+
+
+def generate_torch_install_command(installed_cuda_version: Version, system_type):
+    packages = ["torch", "torchvision"]
+    index_url = determine_torch_index(
+        installed_cuda_version=installed_cuda_version, system_type=system_type
+    )
+    cmd_parts = [
+        "pip",
+        "install",
+        "--upgrade",
+    ]
+    if index_url:
+        cmd_parts.extend(
+            [
+                "--index-url",
+                index_url,
+            ]
+        )
+    cmd_parts.extend(packages)
+    return cmd_parts
+
+
+def install_packages(packages, index_url):
+    """
+    Install a list of Python packages from a specified index server.
+
+    :param packages: A list of package names to install.
+    :param index_url: The URL of the Python package index server.
+    """
+    import sys
+
+    for package in packages:
+        try:
+            subprocess.check_call(
+                [
+                    sys.executable,
+                    "-m",
+                    "pip",
+                    "install",
+                    package,
+                    "--index-url",
+                    index_url,
+                ]
+            )
+            print(f"Successfully installed {package}")
+        except subprocess.CalledProcessError as e:
+            print(f"Failed to install {package}: {e}")

View File

@@ -43,7 +43,7 @@ coverage==7.3.2
     # via -r requirements-dev.in
 cycler==0.12.1
     # via matplotlib
-diffusers==0.23.1
+diffusers==0.24.0
     # via imaginAIry (setup.py)
 einops==0.7.0
     # via imaginAIry (setup.py)
@@ -65,13 +65,13 @@ filelock==3.13.1
     #   transformers
 filterpy==1.4.5
     # via facexlib
-fonttools==4.45.1
+fonttools==4.46.0
     # via matplotlib
 frozenlist==1.4.0
     # via
     #   aiohttp
     #   aiosignal
-fsspec[http]==2023.10.0
+fsspec[http]==2023.12.0
     # via
     #   huggingface-hub
     #   pytorch-lightning
@@ -96,7 +96,7 @@ idna==3.6
     #   yarl
 imageio==2.33.0
     # via imaginAIry (setup.py)
-importlib-metadata==6.8.0
+importlib-metadata==7.0.0
     # via diffusers
 iniconfig==2.0.0
     # via pytest
@@ -170,6 +170,7 @@ packaging==23.2
     #   pytest
     #   pytest-sugar
     #   pytorch-lightning
+    #   torchmetrics
     #   transformers
 pathspec==0.11.2
     # via black
@@ -182,7 +183,7 @@ pillow==10.1.0
     #   matplotlib
     #   refiners
     #   torchvision
-platformdirs==4.0.0
+platformdirs==4.1.0
     # via black
 pluggy==1.3.0
     # via pytest
@@ -239,7 +240,7 @@ requests==2.31.0
     #   transformers
 responses==0.24.1
     # via -r requirements-dev.in
-ruff==0.1.6
+ruff==0.1.7
     # via -r requirements-dev.in
 safetensors==0.3.3
     # via
@@ -264,7 +265,7 @@ starlette==0.27.0
     # via fastapi
 sympy==1.12
     # via torch
-termcolor==2.3.0
+termcolor==2.4.0
     # via pytest-sugar
 timm==0.9.12
     # via
@@ -291,7 +292,7 @@ torch==2.1.1
     #   torchvision
 torchdiffeq==0.2.3
     # via imaginAIry (setup.py)
-torchmetrics==1.2.0
+torchmetrics==1.2.1
     # via
     #   imaginAIry (setup.py)
     #   pytorch-lightning

View File

@@ -153,3 +153,18 @@ def pytest_collection_modifyitems(config, items):
         filtered_node_ids.sort()
         for n in filtered_node_ids:
             print(f"    {n}")
+
+
+def pytest_sessionstart(session):
+    from imaginairy.utils.debug_info import get_debug_info
+
+    debug_info = get_debug_info()
+    for k, v in debug_info.items():
+        if k == "nvidia_smi":
+            continue
+        k += ":"
+        print(f"{k: <30} {v}")
+
+    if "nvidia_smi" in debug_info:
+        print(debug_info["nvidia_smi"])

View File

@@ -0,0 +1,21 @@
+import pytest
+from packaging.version import Version
+
+from imaginairy.utils.torch_installer import determine_torch_index
+
+index_base = "https://download.pytorch.org/whl/"
+index_cu118 = f"{index_base}cu118"
+index_cu121 = f"{index_base}cu121"
+
+torch_index_fixture = [
+    (Version("11.8"), "linux", index_cu118),
+    (Version("12.1"), "linux", ""),
+    (Version("12.2"), "linux", ""),
+    (Version("12.1"), "windows", index_cu121),
+    (Version("12.2"), "windows", index_cu121),
+]
+
+
+@pytest.mark.parametrize(("cuda_version", "platform", "expected"), torch_index_fixture)
+def test_determine_torch_index(cuda_version, platform, expected):
+    assert determine_torch_index(cuda_version, platform) == expected