mirror of
https://github.com/brycedrennan/imaginAIry
synced 2024-10-31 03:20:40 +00:00
feature: better torch installation experience
This commit is contained in:
parent
71d4992dca
commit
24f4af3482
55
.github/workflows/ci.yaml
vendored
55
.github/workflows/ci.yaml
vendored
@ -7,11 +7,6 @@ on:
|
|||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
env:
|
env:
|
||||||
PIP_DISABLE_PIP_VERSION_CHECK: 1
|
PIP_DISABLE_PIP_VERSION_CHECK: 1
|
||||||
CACHE_PATHS: |
|
|
||||||
~/.cache/huggingface
|
|
||||||
~/.cache/clip
|
|
||||||
~/.cache/imaginairy
|
|
||||||
~/.cache/torch
|
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
lint:
|
lint:
|
||||||
@ -54,61 +49,27 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
black --diff --fast .
|
black --diff --fast .
|
||||||
test:
|
test:
|
||||||
runs-on: [self-hosted, cuda]
|
runs-on: ${{ matrix.os }}
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
python-version: ["3.10"]
|
python-version: ["3.10"]
|
||||||
# subset: ["1/10", "2/10", "3/10", "4/10", "5/10", "6/10", "7/10", "8/10", "9/10", "10/10"]
|
os: ["nvidia-4090"]
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v3
|
- uses: actions/checkout@v3
|
||||||
- name: Set up Python ${{ matrix.python-version }}
|
- name: Set up Python ${{ matrix.python-version }}
|
||||||
uses: actions/setup-python@v4
|
uses: actions/setup-python@v4
|
||||||
with:
|
with:
|
||||||
python-version: ${{ matrix.python-version }}
|
python-version: ${{ matrix.python-version }}
|
||||||
cache: pip
|
# cache: pip
|
||||||
cache-dependency-path: requirements-dev.txt
|
# cache-dependency-path: requirements-dev.txt
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
python -m pip install -r requirements-dev.txt .
|
python -m pip uninstall torch torchvision xformers triton imaginairy -y
|
||||||
# - name: Get current date
|
python -m pip install -r requirements-dev.in . --upgrade
|
||||||
# id: date
|
|
||||||
# run: echo "::set-output name=curmonth::$(date +'%Y-%m')"
|
|
||||||
# - name: Cache Model Files
|
|
||||||
# id: cache-model-files
|
|
||||||
# uses: actions/cache/restore@v3
|
|
||||||
# with:
|
|
||||||
# path: ${{ env.CACHE_PATHS }}
|
|
||||||
# key: ${{ steps.date.outputs.curmonth }}-b
|
|
||||||
# Generate initial file list for all directories
|
|
||||||
# - name: Generate initial model file list
|
|
||||||
# run: |
|
|
||||||
# for dir in $CACHE_PATHS; do
|
|
||||||
# if [ -d "$dir" ]; then
|
|
||||||
# find $dir
|
|
||||||
# fi
|
|
||||||
# done > initial_file_list.txt
|
|
||||||
- name: Test with pytest
|
- name: Test with pytest
|
||||||
timeout-minutes: 30
|
timeout-minutes: 30
|
||||||
|
env:
|
||||||
|
CUDA_LAUNCH_BLOCKING: 1
|
||||||
run: |
|
run: |
|
||||||
pytest --durations=10 -v
|
pytest --durations=10 -v
|
||||||
# Generate final file list and check for new files
|
|
||||||
# - name: Generate final model file list
|
|
||||||
# run: |
|
|
||||||
# for dir in CACHE_PATHS; do
|
|
||||||
# if [ -d "$dir" ]; then
|
|
||||||
# find $dir
|
|
||||||
# fi
|
|
||||||
# done > final_file_list.txt
|
|
||||||
# if ! diff initial_file_list.txt final_file_list.txt > /dev/null; then
|
|
||||||
# echo "New files detected."
|
|
||||||
# echo "new_files=true" >> $GITHUB_ENV
|
|
||||||
# else
|
|
||||||
# echo "No new files detected."
|
|
||||||
# fi
|
|
||||||
# - uses: actions/cache/save@v3
|
|
||||||
# id: cache
|
|
||||||
# if: env.new_files == 'true'
|
|
||||||
# with:
|
|
||||||
# path: ${{ env.CACHE_PATHS }}
|
|
||||||
# key: ${{ steps.date.outputs.curmonth }}-b
|
|
@ -82,6 +82,7 @@ Options:
|
|||||||
cutting edge features (SDXL, image prompts, etc) which will be added to imaginairy soon.
|
cutting edge features (SDXL, image prompts, etc) which will be added to imaginairy soon.
|
||||||
- [self-attention guidance](https://github.com/SusungHong/Self-Attention-Guidance) which makes details of images more accurate
|
- [self-attention guidance](https://github.com/SusungHong/Self-Attention-Guidance) which makes details of images more accurate
|
||||||
- feature: added `--size` parameter for more intuitive sizing (e.g. 512, 256x256, 4k, uhd, FHD, VGA, etc)
|
- feature: added `--size` parameter for more intuitive sizing (e.g. 512, 256x256, 4k, uhd, FHD, VGA, etc)
|
||||||
|
- feature: detect if wrong torch version is installed and provide instructions on how to install proper version
|
||||||
- feature: better logging output: color, error handling
|
- feature: better logging output: color, error handling
|
||||||
- feature: support for pytorch 2.0
|
- feature: support for pytorch 2.0
|
||||||
- deprecated: support for python 3.8, 3.9
|
- deprecated: support for python 3.8, 3.9
|
||||||
|
@ -172,7 +172,10 @@ def imagine(
|
|||||||
num_prompts = "?"
|
num_prompts = "?"
|
||||||
|
|
||||||
if get_device() == "cpu":
|
if get_device() == "cpu":
|
||||||
logger.info("Running in CPU mode. it's gonna be slooooooow.")
|
logger.warning("Running in CPU mode. It's gonna be slooooooow.")
|
||||||
|
from imaginairy.utils.torch_installer import torch_version_check
|
||||||
|
|
||||||
|
torch_version_check()
|
||||||
|
|
||||||
if half_mode is None:
|
if half_mode is None:
|
||||||
half_mode = "cuda" in get_device() or get_device() == "mps"
|
half_mode = "cuda" in get_device() or get_device() == "mps"
|
||||||
|
@ -67,12 +67,19 @@ def system_info():
|
|||||||
"""
|
"""
|
||||||
Display system information. Submit this when reporting bugs.
|
Display system information. Submit this when reporting bugs.
|
||||||
"""
|
"""
|
||||||
from imaginairy.debug_info import get_debug_info
|
from imaginairy.utils.debug_info import get_debug_info
|
||||||
|
|
||||||
for k, v in get_debug_info().items():
|
debug_info = get_debug_info()
|
||||||
|
|
||||||
|
for k, v in debug_info.items():
|
||||||
|
if k == "nvidia_smi":
|
||||||
|
continue
|
||||||
k += ":"
|
k += ":"
|
||||||
click.secho(f"{k: <30} {v}")
|
click.secho(f"{k: <30} {v}")
|
||||||
|
|
||||||
|
if "nvidia_smi" in debug_info:
|
||||||
|
click.secho(debug_info["nvidia_smi"])
|
||||||
|
|
||||||
|
|
||||||
@aimg.command("model-list")
|
@aimg.command("model-list")
|
||||||
def model_list_cmd():
|
def model_list_cmd():
|
||||||
|
@ -10,9 +10,9 @@ def run_server_cmd():
|
|||||||
"""Run a HTTP API server."""
|
"""Run a HTTP API server."""
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
|
||||||
|
from imaginairy.cli.shared import imaginairy_click_context
|
||||||
from imaginairy.http_app.app import app
|
from imaginairy.http_app.app import app
|
||||||
from imaginairy.log_utils import configure_logging
|
|
||||||
|
|
||||||
configure_logging(level="DEBUG")
|
with imaginairy_click_context(log_level="DEBUG"):
|
||||||
logger.info("Starting HTTP API server at http://0.0.0.0:8000")
|
logger.info("Starting HTTP API server at http://0.0.0.0:8000")
|
||||||
uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")
|
uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")
|
||||||
|
@ -10,13 +10,13 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def imaginairy_click_context():
|
def imaginairy_click_context(log_level="INFO"):
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
|
|
||||||
from imaginairy.log_utils import configure_logging
|
from imaginairy.log_utils import configure_logging
|
||||||
|
|
||||||
errors_to_catch = (FileNotFoundError, ValidationError)
|
errors_to_catch = (FileNotFoundError, ValidationError)
|
||||||
configure_logging()
|
configure_logging(level=log_level)
|
||||||
try:
|
try:
|
||||||
yield
|
yield
|
||||||
except errors_to_catch as e:
|
except errors_to_catch as e:
|
||||||
|
@ -1,5 +1,9 @@
|
|||||||
|
from functools import lru_cache
|
||||||
|
|
||||||
|
|
||||||
def get_debug_info():
|
def get_debug_info():
|
||||||
import os.path
|
import os.path
|
||||||
|
import platform
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
import psutil
|
import psutil
|
||||||
@ -15,9 +19,10 @@ def get_debug_info():
|
|||||||
"python_installation_path": sys.executable,
|
"python_installation_path": sys.executable,
|
||||||
"device": get_device(),
|
"device": get_device(),
|
||||||
"torch_version": torch.__version__,
|
"torch_version": torch.__version__,
|
||||||
"platform": sys.platform,
|
"platform": platform.system(),
|
||||||
"hardware_description": get_hardware_description(get_device()),
|
"hardware_description": get_hardware_description(get_device()),
|
||||||
"ram_gb": round(psutil.virtual_memory().total / (1024**3), 2),
|
"ram_gb": round(psutil.virtual_memory().total / (1024**3), 2),
|
||||||
|
"cuda_available": torch.cuda.is_available(),
|
||||||
}
|
}
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device_props = torch.cuda.get_device_properties(0)
|
device_props = torch.cuda.get_device_properties(0)
|
||||||
@ -36,4 +41,38 @@ def get_debug_info():
|
|||||||
"graphics_card": "Apple MPS",
|
"graphics_card": "Apple MPS",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
nvidia_data = get_nvidia_smi_data()
|
||||||
|
data.update(nvidia_data)
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache
|
||||||
|
def _get_nvidia_smi_output():
|
||||||
|
import subprocess
|
||||||
|
|
||||||
|
try:
|
||||||
|
output = subprocess.check_output(
|
||||||
|
"nvidia-smi", text=True, shell=True, stderr=subprocess.DEVNULL
|
||||||
|
)
|
||||||
|
except subprocess.CalledProcessError:
|
||||||
|
output = "no nvidia card found"
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def _process_nvidia_smi_output(output):
|
||||||
|
import re
|
||||||
|
|
||||||
|
cuda_version = re.search(r"CUDA Version: (\d+\.\d+)", output)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"cuda_version": cuda_version.group(1) if cuda_version else None,
|
||||||
|
"nvidia_smi": output,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_nvidia_smi_data():
|
||||||
|
smi_output = _get_nvidia_smi_output()
|
||||||
|
return _process_nvidia_smi_output(smi_output)
|
138
imaginairy/utils/torch_installer.py
Normal file
138
imaginairy/utils/torch_installer.py
Normal file
@ -0,0 +1,138 @@
|
|||||||
|
import logging
|
||||||
|
import subprocess
|
||||||
|
|
||||||
|
from packaging.version import Version
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def torch_version_check():
|
||||||
|
if not could_install_better_torch_version():
|
||||||
|
return
|
||||||
|
|
||||||
|
import platform
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from imaginairy.utils.debug_info import get_nvidia_smi_data
|
||||||
|
|
||||||
|
nvidia_data = get_nvidia_smi_data()
|
||||||
|
|
||||||
|
cuda_version = Version(nvidia_data["cuda_version"])
|
||||||
|
cmd_parts = generate_torch_install_command(
|
||||||
|
installed_cuda_version=cuda_version, system_type=platform.system().lower()
|
||||||
|
)
|
||||||
|
cmd_str = " ".join(cmd_parts)
|
||||||
|
linebreak = "*" * 72
|
||||||
|
msg = (
|
||||||
|
f"\n{linebreak}\n"
|
||||||
|
f"torch=={torch.__version__} is installed and unable to use CUDA {cuda_version}.\n\n"
|
||||||
|
"You can install the correct version by running:\n\n"
|
||||||
|
f" pip uninstall torch torchvision -y\n"
|
||||||
|
f" {cmd_str}\n\n"
|
||||||
|
"Installing the correct version will speed up image generation.\n"
|
||||||
|
f"{linebreak}\n"
|
||||||
|
)
|
||||||
|
logger.warning(msg)
|
||||||
|
|
||||||
|
|
||||||
|
def could_install_better_torch_version():
|
||||||
|
import platform
|
||||||
|
|
||||||
|
if platform.system().lower() not in ("windows", "linux"):
|
||||||
|
return False
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
return False
|
||||||
|
|
||||||
|
from imaginairy.utils.debug_info import get_nvidia_smi_data
|
||||||
|
|
||||||
|
nvidia_data = get_nvidia_smi_data()
|
||||||
|
cuda_version = nvidia_data["cuda_version"]
|
||||||
|
if cuda_version is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
cuda_version = Version(cuda_version)
|
||||||
|
determine_torch_index(
|
||||||
|
installed_cuda_version=cuda_version, system_type=platform.system()
|
||||||
|
)
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def determine_torch_index(installed_cuda_version: Version, system_type: str):
|
||||||
|
cuda_pypi_base_url = "https://download.pytorch.org/whl/"
|
||||||
|
min_required_cuda_version = Version("11.8")
|
||||||
|
system_type = system_type.lower()
|
||||||
|
|
||||||
|
if installed_cuda_version < min_required_cuda_version:
|
||||||
|
msg = f"Your CUDA version ({installed_cuda_version}) is too old. Please upgrade to at least CUDA {min_required_cuda_version}."
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
if system_type == "windows":
|
||||||
|
if installed_cuda_version >= Version("12.1"):
|
||||||
|
return f"{cuda_pypi_base_url}cu121"
|
||||||
|
if installed_cuda_version >= Version("12.0"):
|
||||||
|
raise ValueError("You should upgrade to CUDA>=12.1")
|
||||||
|
if installed_cuda_version >= Version("11.8"):
|
||||||
|
return f"{cuda_pypi_base_url}cu118"
|
||||||
|
|
||||||
|
elif system_type == "linux":
|
||||||
|
if installed_cuda_version >= Version("12.1"):
|
||||||
|
return ""
|
||||||
|
if installed_cuda_version >= Version("12.0"):
|
||||||
|
raise ValueError("You should upgrade to CUDA>=12.1")
|
||||||
|
if installed_cuda_version >= Version("11.8"):
|
||||||
|
return f"{cuda_pypi_base_url}cu118"
|
||||||
|
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
def generate_torch_install_command(installed_cuda_version: Version, system_type):
|
||||||
|
packages = ["torch", "torchvision"]
|
||||||
|
index_url = determine_torch_index(
|
||||||
|
installed_cuda_version=installed_cuda_version, system_type=system_type
|
||||||
|
)
|
||||||
|
cmd_parts = [
|
||||||
|
"pip",
|
||||||
|
"install",
|
||||||
|
"--upgrade",
|
||||||
|
]
|
||||||
|
if index_url:
|
||||||
|
cmd_parts.extend(
|
||||||
|
[
|
||||||
|
"--index-url",
|
||||||
|
index_url,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
cmd_parts.extend(packages)
|
||||||
|
return cmd_parts
|
||||||
|
|
||||||
|
|
||||||
|
def install_packages(packages, index_url):
|
||||||
|
"""
|
||||||
|
Install a list of Python packages from a specified index server.
|
||||||
|
|
||||||
|
:param packages: A list of package names to install.
|
||||||
|
:param index_url: The URL of the Python package index server.
|
||||||
|
"""
|
||||||
|
import sys
|
||||||
|
|
||||||
|
for package in packages:
|
||||||
|
try:
|
||||||
|
subprocess.check_call(
|
||||||
|
[
|
||||||
|
sys.executable,
|
||||||
|
"-m",
|
||||||
|
"pip",
|
||||||
|
"install",
|
||||||
|
package,
|
||||||
|
"--index-url",
|
||||||
|
index_url,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
print(f"Successfully installed {package}")
|
||||||
|
except subprocess.CalledProcessError as e:
|
||||||
|
print(f"Failed to install {package}: {e}")
|
@ -43,7 +43,7 @@ coverage==7.3.2
|
|||||||
# via -r requirements-dev.in
|
# via -r requirements-dev.in
|
||||||
cycler==0.12.1
|
cycler==0.12.1
|
||||||
# via matplotlib
|
# via matplotlib
|
||||||
diffusers==0.23.1
|
diffusers==0.24.0
|
||||||
# via imaginAIry (setup.py)
|
# via imaginAIry (setup.py)
|
||||||
einops==0.7.0
|
einops==0.7.0
|
||||||
# via imaginAIry (setup.py)
|
# via imaginAIry (setup.py)
|
||||||
@ -65,13 +65,13 @@ filelock==3.13.1
|
|||||||
# transformers
|
# transformers
|
||||||
filterpy==1.4.5
|
filterpy==1.4.5
|
||||||
# via facexlib
|
# via facexlib
|
||||||
fonttools==4.45.1
|
fonttools==4.46.0
|
||||||
# via matplotlib
|
# via matplotlib
|
||||||
frozenlist==1.4.0
|
frozenlist==1.4.0
|
||||||
# via
|
# via
|
||||||
# aiohttp
|
# aiohttp
|
||||||
# aiosignal
|
# aiosignal
|
||||||
fsspec[http]==2023.10.0
|
fsspec[http]==2023.12.0
|
||||||
# via
|
# via
|
||||||
# huggingface-hub
|
# huggingface-hub
|
||||||
# pytorch-lightning
|
# pytorch-lightning
|
||||||
@ -96,7 +96,7 @@ idna==3.6
|
|||||||
# yarl
|
# yarl
|
||||||
imageio==2.33.0
|
imageio==2.33.0
|
||||||
# via imaginAIry (setup.py)
|
# via imaginAIry (setup.py)
|
||||||
importlib-metadata==6.8.0
|
importlib-metadata==7.0.0
|
||||||
# via diffusers
|
# via diffusers
|
||||||
iniconfig==2.0.0
|
iniconfig==2.0.0
|
||||||
# via pytest
|
# via pytest
|
||||||
@ -170,6 +170,7 @@ packaging==23.2
|
|||||||
# pytest
|
# pytest
|
||||||
# pytest-sugar
|
# pytest-sugar
|
||||||
# pytorch-lightning
|
# pytorch-lightning
|
||||||
|
# torchmetrics
|
||||||
# transformers
|
# transformers
|
||||||
pathspec==0.11.2
|
pathspec==0.11.2
|
||||||
# via black
|
# via black
|
||||||
@ -182,7 +183,7 @@ pillow==10.1.0
|
|||||||
# matplotlib
|
# matplotlib
|
||||||
# refiners
|
# refiners
|
||||||
# torchvision
|
# torchvision
|
||||||
platformdirs==4.0.0
|
platformdirs==4.1.0
|
||||||
# via black
|
# via black
|
||||||
pluggy==1.3.0
|
pluggy==1.3.0
|
||||||
# via pytest
|
# via pytest
|
||||||
@ -239,7 +240,7 @@ requests==2.31.0
|
|||||||
# transformers
|
# transformers
|
||||||
responses==0.24.1
|
responses==0.24.1
|
||||||
# via -r requirements-dev.in
|
# via -r requirements-dev.in
|
||||||
ruff==0.1.6
|
ruff==0.1.7
|
||||||
# via -r requirements-dev.in
|
# via -r requirements-dev.in
|
||||||
safetensors==0.3.3
|
safetensors==0.3.3
|
||||||
# via
|
# via
|
||||||
@ -264,7 +265,7 @@ starlette==0.27.0
|
|||||||
# via fastapi
|
# via fastapi
|
||||||
sympy==1.12
|
sympy==1.12
|
||||||
# via torch
|
# via torch
|
||||||
termcolor==2.3.0
|
termcolor==2.4.0
|
||||||
# via pytest-sugar
|
# via pytest-sugar
|
||||||
timm==0.9.12
|
timm==0.9.12
|
||||||
# via
|
# via
|
||||||
@ -291,7 +292,7 @@ torch==2.1.1
|
|||||||
# torchvision
|
# torchvision
|
||||||
torchdiffeq==0.2.3
|
torchdiffeq==0.2.3
|
||||||
# via imaginAIry (setup.py)
|
# via imaginAIry (setup.py)
|
||||||
torchmetrics==1.2.0
|
torchmetrics==1.2.1
|
||||||
# via
|
# via
|
||||||
# imaginAIry (setup.py)
|
# imaginAIry (setup.py)
|
||||||
# pytorch-lightning
|
# pytorch-lightning
|
||||||
|
@ -153,3 +153,18 @@ def pytest_collection_modifyitems(config, items):
|
|||||||
filtered_node_ids.sort()
|
filtered_node_ids.sort()
|
||||||
for n in filtered_node_ids:
|
for n in filtered_node_ids:
|
||||||
print(f" {n}")
|
print(f" {n}")
|
||||||
|
|
||||||
|
|
||||||
|
def pytest_sessionstart(session):
|
||||||
|
from imaginairy.utils.debug_info import get_debug_info
|
||||||
|
|
||||||
|
debug_info = get_debug_info()
|
||||||
|
|
||||||
|
for k, v in debug_info.items():
|
||||||
|
if k == "nvidia_smi":
|
||||||
|
continue
|
||||||
|
k += ":"
|
||||||
|
print(f"{k: <30} {v}")
|
||||||
|
|
||||||
|
if "nvidia_smi" in debug_info:
|
||||||
|
print(debug_info["nvidia_smi"])
|
||||||
|
21
tests/test_utils/test_torch_installer.py
Normal file
21
tests/test_utils/test_torch_installer.py
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
import pytest
|
||||||
|
from packaging.version import Version
|
||||||
|
|
||||||
|
from imaginairy.utils.torch_installer import determine_torch_index
|
||||||
|
|
||||||
|
index_base = "https://download.pytorch.org/whl/"
|
||||||
|
index_cu118 = f"{index_base}cu118"
|
||||||
|
index_cu121 = f"{index_base}cu121"
|
||||||
|
|
||||||
|
torch_index_fixture = [
|
||||||
|
(Version("11.8"), "linux", index_cu118),
|
||||||
|
(Version("12.1"), "linux", ""),
|
||||||
|
(Version("12.2"), "linux", ""),
|
||||||
|
(Version("12.1"), "windows", index_cu121),
|
||||||
|
(Version("12.2"), "windows", index_cu121),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(("cuda_version", "platform", "expected"), torch_index_fixture)
|
||||||
|
def test_determine_torch_index(cuda_version, platform, expected):
|
||||||
|
assert determine_torch_index(cuda_version, platform) == expected
|
Loading…
Reference in New Issue
Block a user