import contextlib
import csv
import gc
import logging
import os
import sys
from functools import partialmethod
from shutil import rmtree

import pytest
import responses
import torch.cuda
from tqdm import tqdm
from urllib3 import HTTPConnectionPool

from imaginairy import api
from imaginairy.api import imagine
from imaginairy.schema import ImaginePrompt
from imaginairy.utils import (
    fix_torch_group_norm,
    fix_torch_nn_layer_norm,
    get_device,
    platform_appropriate_autocast,
)
from imaginairy.utils.log_utils import (
    configure_logging,
    suppress_annoying_logs_and_warnings,
)
from tests import TESTS_FOLDER

if "pytest" in str(sys.argv):
    suppress_annoying_logs_and_warnings()

logger = logging.getLogger(__name__)

# SOLVERS_FOR_TESTING = SOLVER_TYPE_OPTIONS
# if get_device() == "mps:0":
#     SOLVERS_FOR_TESTING = ["plms", "k_euler_a"]
# elif get_device() == "cpu":
#     SOLVERS_FOR_TESTING = []

SOLVERS_FOR_TESTING = ["ddim", "dpmpp"]


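# Session-wide setup: disables the safety filter, silences noisy logs, recreates the
# test output folder, wraps HTTPConnectionPool.urlopen (a hook for tracing/blocking
# network calls), disables tqdm progress bars, and applies the torch compatibility patches.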
@pytest.fixture(scope="session", autouse=True)
def _pre_setup():
    api.IMAGINAIRY_SAFETY_MODE = "disabled"
    suppress_annoying_logs_and_warnings()
    test_output_folder = f"{TESTS_FOLDER}/test_output"

    # delete the test output folder and recreate it
    with contextlib.suppress(FileNotFoundError):
        rmtree(test_output_folder)
    os.makedirs(test_output_folder, exist_ok=True)

    orig_urlopen = HTTPConnectionPool.urlopen

    def urlopen_tattle(self, method, url, *args, **kwargs):
        # traceback.print_stack()
        # current_test = os.environ.get("PYTEST_CURRENT_TEST", "")
        # print(f"{current_test} {method} {self.host}{url}")
        result = orig_urlopen(self, method, url, *args, **kwargs)

        # raise HTTPError("NO NETWORK CALLS")
        return result

    HTTPConnectionPool.urlopen = urlopen_tattle
    tqdm.__init__ = partialmethod(tqdm.__init__, disable=True)

    # real_randn = torch.randn
    # def randn_tattle(*args, **kwargs):
    #     print("RANDN CALL RANDN CALL")
    #     traceback.print_stack()
    #     return real_randn(*args, **kwargs)
    #
    # torch.randn = randn_tattle

    configure_logging("DEBUG")

    with fix_torch_nn_layer_norm(), fix_torch_group_norm(), platform_appropriate_autocast():
        yield


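# get_device() caches its result; clearing the cache before every test forces device
# detection to be re-evaluated rather than reused from a previous test.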
@pytest.fixture(autouse=True)
def _reset_get_device():
    get_device.cache_clear()


@pytest.fixture()
def filename_base_for_outputs(request):
    filename_base = f"{TESTS_FOLDER}/test_output/{request.node.name}_"
    return filename_base


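# Unlike filename_base_for_outputs above, this uses request.node.originalname, which
# excludes any parametrization suffix, so parametrized tests share one base filename.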
@pytest.fixture()
def filename_base_for_orig_outputs(request):
    filename_base = f"{TESTS_FOLDER}/test_output/{request.node.originalname}_"
    return filename_base


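# Parametrized over SOLVERS_FOR_TESTING: tests requesting this fixture run once per solver.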
@pytest.fixture(params=SOLVERS_FOR_TESTING)
def solver_type(request):
    return request.param


@pytest.fixture()
def mocked_responses():
    with responses.RequestsMock() as rsps:
        yield rsps


def pytest_addoption(parser):
    parser.addoption(
        "--subset",
        action="store",
        default=None,
        help="Runs an exclusive subset of tests: '1/3', '2/3', '3/3'. Useful for distributed testing",
    )

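# Usage note: `pytest --subset 2/3` keeps every 3rd collected test (the 2nd of each
# group of three, by sorted node id); see pytest_collection_modifyitems below.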


@pytest.fixture(scope="session")
def default_model_loaded():
    """Ensure the default model weights are downloaded before dependent tests run."""
    prompt = ImaginePrompt(
        "dogs lying on a hot pink couch",
        size=64,
        steps=2,
        seed=1,
        solver_type="ddim",
    )

    next(imagine(prompt))


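# CUDA-usage tracking: node ids of tests that allocate GPU memory are collected here
# during the run and persisted to a CSV at session end, so later runs can mark them
# with the `gputest` marker at collection time.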
cuda_tests_node_ids = []
cuda_test_tracker_filepath = f"{TESTS_FOLDER}/data/cuda-tests.csv"


@pytest.fixture(autouse=True)
def detect_cuda_tests(request):
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()
        start_memory = torch.cuda.max_memory_allocated()

    yield

    if torch.cuda.is_available():
        end_memory = torch.cuda.max_memory_allocated()
        memory_diff = end_memory - start_memory
        if memory_diff > 0:
            test_id = request.node.nodeid
            print(f"Test {test_id} used {memory_diff} bytes of GPU memory")
            cuda_tests_node_ids.append(test_id)

        torch.cuda.empty_cache()
        gc.collect()


@pytest.hookimpl()
def pytest_collection_modifyitems(config, items):
    """Mark known CUDA-using tests and, when --subset is given, select a subset of tests to run."""

    node_ids_to_mark = read_stored_cuda_test_nodes()
    for item in items:
        if item.nodeid in node_ids_to_mark:
            item.add_marker(pytest.mark.gputest)

    filtered_node_ids = set()
    node_ids = [f.nodeid for f in items]
    node_ids.sort()
    subset = config.getoption("--subset")

    if subset:
        partition_no, total_partitions = subset.split("/")
        partition_no, total_partitions = int(partition_no), int(total_partitions)
        if partition_no < 1 or partition_no > total_partitions:
            raise ValueError("Invalid subset")
        for i, node_id in enumerate(node_ids):
            if i % total_partitions == partition_no - 1:
                filtered_node_ids.add(node_id)

        items[:] = [i for i in items if i.nodeid in filtered_node_ids]

        print(
            f"Running subset {partition_no}/{total_partitions} {len(filtered_node_ids)} tests:"
        )
        filtered_node_ids = list(filtered_node_ids)
        filtered_node_ids.sort()
        for n in filtered_node_ids:
            print(f" {n}")


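# Print debug info about the environment (including nvidia-smi output when available) at session start.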
def pytest_sessionstart(session):
    from imaginairy.utils.debug_info import get_debug_info

    debug_info = get_debug_info()

    for k, v in debug_info.items():
        if k == "nvidia_smi":
            continue
        k += ":"
        print(f"{k: <30} {v}")

    if "nvidia_smi" in debug_info:
        print(debug_info["nvidia_smi"])


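# Persist the accumulated CUDA-test node ids so the next run can mark those tests
# with `gputest` during collection (see pytest_collection_modifyitems above).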
def pytest_sessionfinish(session, exitstatus):
    existing_node_ids = read_stored_cuda_test_nodes()
    updated_node_ids = existing_node_ids.union(set(cuda_tests_node_ids))

    # Write updated, sorted list of node IDs to file
    with open(cuda_test_tracker_filepath, "w", newline="") as file:
        writer = csv.writer(file, lineterminator="\n")
        for node_id in sorted(updated_node_ids):
            writer.writerow([node_id])


def read_stored_cuda_test_nodes():
    node_ids = set()
    try:
        with open(cuda_test_tracker_filepath, newline="") as file:
            reader = csv.reader(file)
            for row in reader:
                node_ids.add(row[0])
    except FileNotFoundError:
        pass  # File does not exist yet
    return node_ids