fix: use py3.7 compat lru_cache

- disable lint fixer that updates to newer syntax
This commit is contained in:
Bryce 2023-01-22 17:28:17 -08:00 committed by Bryce Drennan
parent fce9fe9225
commit 1c986d8644
11 changed files with 34 additions and 38 deletions

View File

@ -30,7 +30,7 @@ af: autoformat ## Alias for `autoformat`
autoformat: ## Run the autoformatter. autoformat: ## Run the autoformatter.
@pycln . --all --quiet --extend-exclude __init__\.py @pycln . --all --quiet --extend-exclude __init__\.py
@# ERA,T201 @# ERA,T201
@-ruff --extend-ignore ANN,ARG001,C90,DTZ,D100,D101,D102,D103,D202,D203,D212,D415,E501,RET504,S101,UP006,UP007 --extend-select C,D400,I,UP,W --unfixable T,ERA --fix-only . @-ruff --extend-ignore ANN,ARG001,C90,DTZ,D100,D101,D102,D103,D202,D203,D212,D415,E501,RET504,S101,UP006,UP007 --extend-select C,D400,I,W --unfixable T,ERA --fix-only .
@black . @black .
test: ## Run the tests. test: ## Run the tests.

View File

@ -14,7 +14,7 @@ from imaginairy.vendored.clipseg import CLIPDensePredT
weights_url = "https://github.com/timojl/clipseg/raw/master/weights/rd64-uni.pth" weights_url = "https://github.com/timojl/clipseg/raw/master/weights/rd64-uni.pth"
@lru_cache @lru_cache()
def clip_mask_model(): def clip_mask_model():
from imaginairy.paths import PKG_ROOT # noqa from imaginairy.paths import PKG_ROOT # noqa

View File

@ -17,7 +17,7 @@ if "mps" in device:
BLIP_EVAL_SIZE = 384 BLIP_EVAL_SIZE = 384
@lru_cache @lru_cache()
def blip_model(): def blip_model():
from imaginairy.paths import PKG_ROOT # noqa from imaginairy.paths import PKG_ROOT # noqa

View File

@ -10,7 +10,7 @@ from imaginairy.vendored import clip
device = "cuda" if torch.cuda.is_available() else "cpu" device = "cuda" if torch.cuda.is_available() else "cpu"
@lru_cache @lru_cache()
def get_model(): def get_model():
model_name = "ViT-L/14" model_name = "ViT-L/14"
model, preprocess = clip.load(model_name, device=device) model, preprocess = clip.load(model_name, device=device)

View File

@ -17,7 +17,7 @@ face_restore_device = torch.device("cuda" if torch.cuda.is_available() else "cpu
half_mode = face_restore_device == "cuda" half_mode = face_restore_device == "cuda"
@lru_cache @lru_cache()
def codeformer_model(): def codeformer_model():
model = CodeFormer( model = CodeFormer(
dim_embd=512, dim_embd=512,
@ -36,7 +36,7 @@ def codeformer_model():
return model return model
@lru_cache @lru_cache()
def face_restore_helper(): def face_restore_helper():
""" """
Provide a singleton of FaceRestoreHelper. Provide a singleton of FaceRestoreHelper.

View File

@ -15,7 +15,7 @@ formatter = Formatter()
PROMPT_EXPANSION_PATTERN = re.compile(r"[|a-z0-9_ -]+") PROMPT_EXPANSION_PATTERN = re.compile(r"[|a-z0-9_ -]+")
@lru_cache @lru_cache()
def prompt_library_filepaths(prompt_library_paths=None): def prompt_library_filepaths(prompt_library_paths=None):
"""Return all available category/filepath pairs.""" """Return all available category/filepath pairs."""
prompt_library_paths = [] if not prompt_library_paths else prompt_library_paths prompt_library_paths = [] if not prompt_library_paths else prompt_library_paths
@ -27,7 +27,7 @@ def prompt_library_filepaths(prompt_library_paths=None):
return combined_prompt_library_filepaths return combined_prompt_library_filepaths
@lru_cache @lru_cache()
def category_list(prompt_library_paths=None): def category_list(prompt_library_paths=None):
"""Return the names of available phrase-lists.""" """Return the names of available phrase-lists."""
categories = list(prompt_library_filepaths(prompt_library_paths).keys()) categories = list(prompt_library_filepaths(prompt_library_paths).keys())
@ -35,7 +35,7 @@ def category_list(prompt_library_paths=None):
return categories return categories
@lru_cache @lru_cache()
def prompt_library_filepath(library_path): def prompt_library_filepath(library_path):
lookup = {} lookup = {}

View File

@ -10,7 +10,7 @@ from imaginairy.vendored.basicsr.rrdbnet_arch import RRDBNet
from imaginairy.vendored.realesrgan import RealESRGANer from imaginairy.vendored.realesrgan import RealESRGANer
@lru_cache @lru_cache()
def realesrgan_upsampler(): def realesrgan_upsampler():
model = RRDBNet( model = RRDBNet(
num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4 num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4

View File

@ -117,7 +117,7 @@ class EnhancedStableDiffusionSafetyChecker(
return safety_results return safety_results
@lru_cache @lru_cache()
def safety_models(): def safety_models():
safety_model_id = "CompVis/stable-diffusion-safety-checker" safety_model_id = "CompVis/stable-diffusion-safety-checker"
monkeypatch_safety_cosine_distance() monkeypatch_safety_cosine_distance()
@ -128,7 +128,7 @@ def safety_models():
return safety_feature_extractor, safety_checker return safety_feature_extractor, safety_checker
@lru_cache @lru_cache()
def monkeypatch_safety_cosine_distance(): def monkeypatch_safety_cosine_distance():
orig_cosine_distance = safety_checker_mod.cosine_distance orig_cosine_distance = safety_checker_mod.cosine_distance

View File

@ -13,7 +13,7 @@ from torch.overrides import handle_torch_function, has_torch_function_variadic
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@lru_cache @lru_cache()
def get_device() -> str: def get_device() -> str:
"""Return the best torch backend available.""" """Return the best torch backend available."""
if torch.cuda.is_available(): if torch.cuda.is_available():
@ -25,7 +25,7 @@ def get_device() -> str:
return "cpu" return "cpu"
@lru_cache @lru_cache()
def get_hardware_description(device_type: str) -> str: def get_hardware_description(device_type: str) -> str:
"""Description of the hardware being used.""" """Description of the hardware being used."""
desc = platform.platform() desc = platform.platform()

View File

@ -7,14 +7,14 @@ import ftfy
import regex as re import regex as re
@lru_cache @lru_cache()
def default_bpe(): def default_bpe():
return os.path.join( return os.path.join(
os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz" os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz"
) )
@lru_cache @lru_cache()
def bytes_to_unicode(): def bytes_to_unicode():
""" """
Returns list of utf-8 byte and a corresponding list of unicode strings. Returns list of utf-8 byte and a corresponding list of unicode strings.

View File

@ -10,7 +10,7 @@ aiosignal==1.3.1
# via aiohttp # via aiohttp
antlr4-python3-runtime==4.9.3 antlr4-python3-runtime==4.9.3
# via omegaconf # via omegaconf
astroid==2.13.2 astroid==2.13.3
# via pylint # via pylint
async-timeout==4.0.2 async-timeout==4.0.2
# via aiohttp # via aiohttp
@ -34,9 +34,9 @@ click==8.1.3
# typer # typer
click-shell==2.1 click-shell==2.1
# via imaginAIry (setup.py) # via imaginAIry (setup.py)
contourpy==1.0.6 contourpy==1.0.7
# via matplotlib # via matplotlib
coverage==7.0.4 coverage==7.0.5
# via -r requirements-dev.in # via -r requirements-dev.in
cycler==0.11.0 cycler==0.11.0
# via matplotlib # via matplotlib
@ -65,7 +65,7 @@ frozenlist==1.3.3
# via # via
# aiohttp # aiohttp
# aiosignal # aiosignal
fsspec[http]==2022.11.0 fsspec[http]==2023.1.0
# via pytorch-lightning # via pytorch-lightning
ftfy==6.1.1 ftfy==6.1.1
# via # via
@ -103,7 +103,7 @@ lightning-utilities==0.5.0
# via pytorch-lightning # via pytorch-lightning
llvmlite==0.39.1 llvmlite==0.39.1
# via numba # via numba
matplotlib==3.6.2 matplotlib==3.6.3
# via filterpy # via filterpy
mccabe==0.7.0 mccabe==0.7.0
# via # via
@ -133,13 +133,12 @@ numpy==1.23.5
# opencv-python # opencv-python
# pytorch-lightning # pytorch-lightning
# scipy # scipy
# tensorboardx
# torchmetrics # torchmetrics
# torchvision # torchvision
# transformers # transformers
omegaconf==2.3.0 omegaconf==2.3.0
# via imaginAIry (setup.py) # via imaginAIry (setup.py)
open-clip-torch==2.9.2 open-clip-torch==2.9.3
# via imaginAIry (setup.py) # via imaginAIry (setup.py)
opencv-python==4.7.0.68 opencv-python==4.7.0.68
# via # via
@ -156,7 +155,7 @@ packaging==23.0
# pytorch-lightning # pytorch-lightning
# torchmetrics # torchmetrics
# transformers # transformers
pathspec==0.9.0 pathspec==0.10.3
# via # via
# black # black
# pycln # pycln
@ -174,28 +173,27 @@ platformdirs==2.6.2
# pylint # pylint
pluggy==1.0.0 pluggy==1.0.0
# via pytest # via pytest
protobuf==3.20.1 protobuf==3.20.3
# via # via
# imaginAIry (setup.py) # imaginAIry (setup.py)
# open-clip-torch # open-clip-torch
# tensorboardx
psutil==5.9.4 psutil==5.9.4
# via imaginAIry (setup.py) # via imaginAIry (setup.py)
pycln==2.1.2 pycln==2.1.3
# via -r requirements-dev.in # via -r requirements-dev.in
pycodestyle==2.10.0 pycodestyle==2.10.0
# via pylama # via pylama
pydocstyle==6.2.3 pydocstyle==6.3.0
# via pylama # via pylama
pyflakes==3.0.1 pyflakes==3.0.1
# via pylama # via pylama
pylama==8.4.1 pylama==8.4.1
# via -r requirements-dev.in # via -r requirements-dev.in
pylint==2.15.9 pylint==2.15.10
# via -r requirements-dev.in # via -r requirements-dev.in
pyparsing==3.0.9 pyparsing==3.0.9
# via matplotlib # via matplotlib
pytest==7.2.0 pytest==7.2.1
# via # via
# -r requirements-dev.in # -r requirements-dev.in
# pytest-randomly # pytest-randomly
@ -206,7 +204,7 @@ pytest-sugar==0.9.6
# via -r requirements-dev.in # via -r requirements-dev.in
python-dateutil==2.8.2 python-dateutil==2.8.2
# via matplotlib # via matplotlib
pytorch-lightning==1.8.6 pytorch-lightning==1.9.0
# via imaginAIry (setup.py) # via imaginAIry (setup.py)
pyyaml==6.0 pyyaml==6.0
# via # via
@ -222,7 +220,7 @@ regex==2022.10.31
# diffusers # diffusers
# open-clip-torch # open-clip-torch
# transformers # transformers
requests==2.28.1 requests==2.28.2
# via # via
# diffusers # diffusers
# fsspec # fsspec
@ -233,9 +231,9 @@ requests==2.28.1
# transformers # transformers
responses==0.22.0 responses==0.22.0
# via -r requirements-dev.in # via -r requirements-dev.in
ruff==0.0.215 ruff==0.0.230
# via -r requirements-dev.in # via -r requirements-dev.in
safetensors==0.2.7 safetensors==0.2.8
# via imaginAIry (setup.py) # via imaginAIry (setup.py)
scipy==1.10.0 scipy==1.10.0
# via # via
@ -248,8 +246,6 @@ six==1.16.0
# via python-dateutil # via python-dateutil
snowballstemmer==2.2.0 snowballstemmer==2.2.0
# via pydocstyle # via pydocstyle
tensorboardx==2.5.1
# via pytorch-lightning
termcolor==2.2.0 termcolor==2.2.0
# via pytest-sugar # via pytest-sugar
timm==0.6.12 timm==0.6.12
@ -317,11 +313,11 @@ typing-extensions==4.4.0
# typing-inspect # typing-inspect
typing-inspect==0.8.0 typing-inspect==0.8.0
# via libcst # via libcst
urllib3==1.26.13 urllib3==1.26.14
# via # via
# requests # requests
# responses # responses
wcwidth==0.2.5 wcwidth==0.2.6
# via ftfy # via ftfy
wheel==0.38.4 wheel==0.38.4
# via -r requirements-dev.in # via -r requirements-dev.in