2022-09-25 02:42:54 +00:00
|
|
|
import logging
|
|
|
|
import os
|
2022-09-15 02:40:50 +00:00
|
|
|
import sys
|
|
|
|
|
|
|
|
import pytest
|
2022-09-25 02:42:54 +00:00
|
|
|
from urllib3 import HTTPConnectionPool
|
2022-09-15 02:40:50 +00:00
|
|
|
|
2022-09-16 06:06:59 +00:00
|
|
|
from imaginairy import api
|
2022-10-11 02:50:11 +00:00
|
|
|
from imaginairy.log_utils import suppress_annoying_logs_and_warnings
|
2022-09-22 05:38:44 +00:00
|
|
|
from imaginairy.utils import (
|
|
|
|
fix_torch_group_norm,
|
|
|
|
fix_torch_nn_layer_norm,
|
2022-09-28 00:04:16 +00:00
|
|
|
get_device,
|
2022-09-22 05:38:44 +00:00
|
|
|
platform_appropriate_autocast,
|
|
|
|
)
|
2022-09-25 02:42:54 +00:00
|
|
|
from tests import TESTS_FOLDER
|
2022-09-15 02:40:50 +00:00
|
|
|
|
|
|
|
# Call suppress_annoying_logs_and_warnings() at conftest import time (before any
# fixture runs) whenever one of the command-line arguments mentions "pytest".
if any("pytest" in arg for arg in sys.argv):
    suppress_annoying_logs_and_warnings()

# Module-level logger for this conftest.
logger = logging.getLogger(__name__)
|
|
|
|
|
2022-09-15 02:40:50 +00:00
|
|
|
|
|
|
|
@pytest.fixture(scope="session", autouse=True)
def pre_setup():
    """Prepare process-wide state shared by every test in the session.

    Setup performed once per session:

    - sets ``api.IMAGINAIRY_SAFETY_MODE`` to ``"disabled"``;
    - calls ``suppress_annoying_logs_and_warnings()``;
    - ensures ``<TESTS_FOLDER>/test_output`` exists (without clearing it);
    - wraps ``urllib3.HTTPConnectionPool.urlopen`` so every HTTP request is
      printed together with the currently running test (``PYTEST_CURRENT_TEST``),
      making unexpected network access visible in test output;
    - runs the whole session inside ``fix_torch_nn_layer_norm()``,
      ``fix_torch_group_norm()`` and ``platform_appropriate_autocast()``.

    At session teardown the original ``urlopen`` is restored, even if the
    context managers raise while exiting.
    """
    api.IMAGINAIRY_SAFETY_MODE = "disabled"
    suppress_annoying_logs_and_warnings()
    os.makedirs(f"{TESTS_FOLDER}/test_output", exist_ok=True)

    orig_urlopen = HTTPConnectionPool.urlopen

    def urlopen_tattle(self, method, url, *args, **kwargs):
        # Report which test triggered the request, then the request itself,
        # delegating the real work to the original implementation.
        print(os.environ.get("PYTEST_CURRENT_TEST"))
        print(f"{method} {self.host}{url}")
        result = orig_urlopen(self, method, url, *args, **kwargs)
        print(f"{method} {self.host}{url} DONE")
        return result

    HTTPConnectionPool.urlopen = urlopen_tattle
    try:
        with fix_torch_nn_layer_norm(), fix_torch_group_norm(), platform_appropriate_autocast():
            yield
    finally:
        # Undo the global monkeypatch so it does not outlive this fixture.
        HTTPConnectionPool.urlopen = orig_urlopen
|
2022-09-28 00:04:16 +00:00
|
|
|
|
|
|
|
|
2022-10-06 04:43:00 +00:00
|
|
|
@pytest.fixture(autouse=True)
def reset_get_device():
    """Invalidate the cached result of ``get_device`` before every test.

    ``get_device`` exposes ``cache_clear`` (it is memoized), so without this a
    device value computed during one test would be reused by later ones.
    Presumably this lets tests that influence device selection observe a fresh
    lookup; confirm against the tests that rely on it.
    """
    invalidate_device_cache = get_device.cache_clear
    invalidate_device_cache()
|
|
|
|
|
|
|
|
|
2022-09-28 00:04:16 +00:00
|
|
|
@pytest.fixture()
def filename_base_for_outputs(request):
    """Return a per-test path prefix for files the test writes.

    The prefix has the form ``<TESTS_FOLDER>/test_output/<node name>_<device>_``;
    callers append their own suffix. The current device is part of the name,
    presumably so runs on different devices do not overwrite each other.
    """
    test_name = request.node.name
    device = get_device()
    return f"{TESTS_FOLDER}/test_output/{test_name}_{device}_"
|