mirror of https://github.com/brycedrennan/imaginAIry
synced 2024-10-31 03:20:40 +00:00

lint: new ruff linter

This commit is contained in: parent a1871e9d3a, commit 1381c7fed4

Makefile (3 lines changed)
@@ -27,7 +27,8 @@ init: require_pyenv ## Setup a dev environment for local development.
af: autoformat ## Alias for `autoformat`
autoformat: ## Run the autoformatter.
@pycln . --all --quiet --extend-exclude __init__\.py
@isort --atomic --profile black .
@# ERA,T201
@-ruff --extend-ignore ANN,ARG001,C90,DTZ,D100,D101,D102,D103,D202,D203,D212,D415,E501,RET504,S101,UP006,UP007 --extend-select C,D400,I,UP,W --unfixable T,ERA --fix-only .
@black .

test: ## Run the tests.
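Everything below this Makefile change is the mechanical fallout of running that ruff command: redundant parentheses dropped from @lru_cache, dict(...) calls turned into literals, map(lambda ...) turned into generator expressions, and lambdas assigned to names turned into def statements. A standalone before/after sketch of those shapes (illustration only, not code from this repository):

    from functools import lru_cache

    # Shapes found in the old code:
    @lru_cache()                                   # redundant parentheses
    def fold_params_old():
        return dict(kernel_size=3, dilation=1)     # dict() call with keyword args

    sigma_fn_old = lambda t: -t                    # lambda assigned to a name
    doubled_old = map(lambda v: v * 2, [1, 2, 3])  # map over a lambda

    # Equivalent shapes after the autofixes:
    @lru_cache                                     # bare decorator
    def fold_params_new():
        return {"kernel_size": 3, "dilation": 1}   # dict literal

    def sigma_fn_new(t):                           # def instead of a lambda
        return -t

    doubled_new = (v * 2 for v in [1, 2, 3])       # generator expression

    print(fold_params_old() == fold_params_new())  # True
    print(sigma_fn_old(2) == sigma_fn_new(2))      # True
    print(list(doubled_old) == list(doubled_new))  # True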
@@ -223,7 +223,7 @@ def imagine_cmd(
model_weights_path,
prompt_library_path,
):
"""Have the AI generate images. alias:imagine"""
"""Have the AI generate images. alias:imagine."""
if ctx.invoked_subcommand is not None:
return

@@ -303,7 +303,7 @@ def aimg():
@click.argument("image_filepaths", nargs=-1)
@aimg.command()
def describe(image_filepaths):
"""Generate text descriptions of images"""
"""Generate text descriptions of images."""
imgs = []
for p in image_filepaths:
if p.startswith("http"):

@@ -14,7 +14,7 @@ from imaginairy.vendored.clipseg import CLIPDensePredT
weights_url = "https://github.com/timojl/clipseg/raw/master/weights/rd64-uni.pth"


@lru_cache()
@lru_cache
def clip_mask_model():
from imaginairy.paths import PKG_ROOT # noqa

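The @lru_cache() → @lru_cache change above repeats across most of the modules that follow. For context (a standalone sketch, not repository code): since Python 3.8, functools.lru_cache accepts the decorated function directly, so the bare form is behaviorally identical to the empty call, including the default maxsize of 128.

    from functools import lru_cache

    @lru_cache            # same as @lru_cache() on Python 3.8+
    def load_model_stub():
        return object()   # stands in for an expensive model load

    first = load_model_stub()
    second = load_model_stub()
    print(first is second)               # True: the result is cached
    print(load_model_stub.cache_info())  # CacheInfo(hits=1, misses=1, maxsize=128, currsize=1)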
@@ -17,7 +17,7 @@ if "mps" in device:
BLIP_EVAL_SIZE = 384


@lru_cache()
@lru_cache
def blip_model():
from imaginairy.paths import PKG_ROOT # noqa

@@ -35,7 +35,7 @@ def blip_model():


def generate_caption(image, min_length=30):
"""Given an image, return a caption"""
"""Given an image, return a caption."""
gpu_image = (
transforms.Compose(
[

@@ -10,7 +10,7 @@ from imaginairy.vendored import clip
device = "cuda" if torch.cuda.is_available() else "cpu"


@lru_cache()
@lru_cache
def get_model():
model_name = "ViT-L/14"
model, preprocess = clip.load(model_name, device=device)
@@ -18,7 +18,7 @@ def get_model():


def find_img_text_similarity(image: Image.Image, phrases: Sequence):
"""Find the likelihood of a list of textual concepts existing in the image"""
"""Find the likelihood of a list of textual concepts existing in the image."""

model, preprocess = get_model()
image = preprocess(image).unsqueeze(0).to(device)
@@ -14,7 +14,7 @@ from imaginairy.vendored.codeformer.codeformer_arch import CodeFormer
logger = logging.getLogger(__name__)


@lru_cache()
@lru_cache
def codeformer_model():
model = CodeFormer(
dim_embd=512,
@@ -31,10 +31,10 @@ def codeformer_model():
return model


@lru_cache()
@lru_cache
def face_restore_helper():
"""
Provide a singleton of FaceRestoreHelper
Provide a singleton of FaceRestoreHelper.

FaceRestoreHelper loads a model internally so we need to cache it
or we end up with a memory leak

@ -15,9 +15,9 @@ formatter = Formatter()
|
||||
PROMPT_EXPANSION_PATTERN = re.compile(r"[|a-z0-9_ -]+")
|
||||
|
||||
|
||||
@lru_cache()
|
||||
@lru_cache
|
||||
def prompt_library_filepaths(prompt_library_paths=None):
|
||||
"""Return all available category/filepath pairs"""
|
||||
"""Return all available category/filepath pairs."""
|
||||
prompt_library_paths = [] if not prompt_library_paths else prompt_library_paths
|
||||
combined_prompt_library_filepaths = {}
|
||||
for prompt_path in DEFAULT_PROMPT_LIBRARY_PATHS + list(prompt_library_paths):
|
||||
@ -27,15 +27,15 @@ def prompt_library_filepaths(prompt_library_paths=None):
|
||||
return combined_prompt_library_filepaths
|
||||
|
||||
|
||||
@lru_cache()
|
||||
@lru_cache
|
||||
def category_list(prompt_library_paths=None):
|
||||
"""Return the names of available phrase-lists"""
|
||||
"""Return the names of available phrase-lists."""
|
||||
categories = list(prompt_library_filepaths(prompt_library_paths).keys())
|
||||
categories.sort()
|
||||
return categories
|
||||
|
||||
|
||||
@lru_cache()
|
||||
@lru_cache
|
||||
def prompt_library_filepath(library_path):
|
||||
lookup = {}
|
||||
|
||||
@ -70,7 +70,7 @@ def get_phrases(category_name, prompt_library_paths=None):
|
||||
|
||||
def expand_prompts(prompt_text, n=1, prompt_library_paths=None):
|
||||
"""
|
||||
Replaces {vars} with random samples of corresponding phraselists
|
||||
Replaces {vars} with random samples of corresponding phraselists.
|
||||
|
||||
Example:
|
||||
p = "a happy {animal}"
|
||||
|
@ -10,7 +10,7 @@ from imaginairy.model_manager import get_cached_url_path
|
||||
from imaginairy.utils import get_device
|
||||
|
||||
|
||||
@lru_cache()
|
||||
@lru_cache
|
||||
def realesrgan_upsampler():
|
||||
model = RRDBNet(
|
||||
num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4
|
||||
|
@ -21,7 +21,10 @@ def pillow_fit_image_within(image: PIL.Image.Image, max_height=512, max_width=51
|
||||
|
||||
if resize_ratio != 1:
|
||||
w, h = int(w * resize_ratio), int(h * resize_ratio)
|
||||
w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64
|
||||
# resize to integer multiple of 64
|
||||
w -= w % 64
|
||||
h -= h % 64
|
||||
|
||||
if (w, h) != image.size:
|
||||
image = image.resize((w, h), resample=Image.Resampling.LANCZOS)
|
||||
return image
|
||||
|
@ -29,7 +29,7 @@ class HuggingFaceAuthorizationError(RuntimeError):
|
||||
|
||||
|
||||
class MemoryAwareModel:
|
||||
"""Wraps a model to allow dynamic loading/unloading as needed"""
|
||||
"""Wraps a model to allow dynamic loading/unloading as needed."""
|
||||
|
||||
def __init__(self, config_path, weights_path, half_mode=None):
|
||||
self._config_path = config_path
|
||||
@ -122,7 +122,7 @@ def get_diffusion_model(
|
||||
for_inpainting=False,
|
||||
):
|
||||
"""
|
||||
Load a diffusion model
|
||||
Load a diffusion model.
|
||||
|
||||
Weights location may also be shortcut name, e.g. "SD-1.5"
|
||||
"""
|
||||
@ -148,7 +148,7 @@ def _get_diffusion_model(
|
||||
for_inpainting=False,
|
||||
):
|
||||
"""
|
||||
Load a diffusion model
|
||||
Load a diffusion model.
|
||||
|
||||
Weights location may also be shortcut name, e.g. "SD-1.5"
|
||||
"""
|
||||
@ -217,7 +217,7 @@ def get_cache_dir():
|
||||
|
||||
def get_cached_url_path(url):
|
||||
"""
|
||||
Gets the contents of a url, but caches the response indefinitely
|
||||
Gets the contents of a url, but caches the response indefinitely.
|
||||
|
||||
While we attempt to use the cached_path from huggingface transformers, we fall back
|
||||
to our own implementation if the url does not provide an etag header, which `cached_path`
|
||||
|
@ -184,7 +184,7 @@ class CrossAttention(nn.Module):
|
||||
k = self.to_k(context) * self.scale
|
||||
v = self.to_v(context)
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
|
||||
q, k, v = (rearrange(t, "b n (h d) -> (b h) n d", h=h) for t in (q, k, v))
|
||||
|
||||
# force cast to fp32 to avoid overflowing
|
||||
if ATTENTION_PRECISION_OVERRIDE == "fp32":
|
||||
@ -219,8 +219,8 @@ class CrossAttention(nn.Module):
|
||||
v_in = self.to_v(context)
|
||||
del context, x
|
||||
|
||||
q, k, v = map(
|
||||
lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q_in, k_in, v_in)
|
||||
q, k, v = (
|
||||
rearrange(t, "b n (h d) -> (b h) n d", h=h) for t in (q_in, k_in, v_in)
|
||||
)
|
||||
del q_in, k_in, v_in
|
||||
|
||||
@ -300,13 +300,13 @@ class MemoryEfficientCrossAttention(nn.Module):
|
||||
v = self.to_v(context)
|
||||
|
||||
b, _, _ = q.shape
|
||||
q, k, v = map(
|
||||
lambda t: t.unsqueeze(3)
|
||||
q, k, v = (
|
||||
t.unsqueeze(3)
|
||||
.reshape(b, t.shape[1], self.heads, self.dim_head)
|
||||
.permute(0, 2, 1, 3)
|
||||
.reshape(b * self.heads, t.shape[1], self.dim_head)
|
||||
.contiguous(),
|
||||
(q, k, v),
|
||||
.contiguous()
|
||||
for t in (q, k, v)
|
||||
)
|
||||
|
||||
# actually compute the attention, what we cannot get enough of
|
||||
@ -392,7 +392,7 @@ class SpatialTransformer(nn.Module):
|
||||
and reshape to b, t, d.
|
||||
Then apply standard transformer action.
|
||||
Finally, reshape to image
|
||||
NEW: use_linear for more efficiency instead of the 1x1 convs
|
||||
NEW: use_linear for more efficiency instead of the 1x1 convs.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
@ -168,7 +168,9 @@ class AutoencoderKL(pl.LightningModule):
|
||||
def validation_step(self, batch, batch_idx):
|
||||
log_dict = self._validation_step(batch, batch_idx)
|
||||
with self.ema_scope():
|
||||
log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
|
||||
log_dict_ema = self._validation_step( # noqa
|
||||
batch, batch_idx, postfix="_ema"
|
||||
)
|
||||
return log_dict
|
||||
|
||||
def _validation_step(self, batch, batch_idx, postfix=""):
|
||||
|
@ -9,7 +9,7 @@ from imaginairy.vendored import clip
|
||||
|
||||
|
||||
class FrozenCLIPEmbedder(nn.Module):
|
||||
"""Uses the CLIP transformer encoder for text (from Hugging Face)"""
|
||||
"""Uses the CLIP transformer encoder for text (from Hugging Face)."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -3,7 +3,7 @@ wild mixture of
|
||||
https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
|
||||
https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py
|
||||
https://github.com/CompVis/taming-transformers
|
||||
-- merci
|
||||
-- merci.
|
||||
"""
|
||||
import itertools
|
||||
import logging
|
||||
@ -66,7 +66,7 @@ class DDPM(pl.LightningModule):
|
||||
beta_schedule="linear",
|
||||
loss_type="l2",
|
||||
ckpt_path=None,
|
||||
ignore_keys=tuple(),
|
||||
ignore_keys=(),
|
||||
load_only_unet=False,
|
||||
monitor="val/loss",
|
||||
use_ema=True,
|
||||
@ -286,7 +286,7 @@ class DDPM(pl.LightningModule):
|
||||
print(f"{context}: Restored training weights")
|
||||
|
||||
@torch.no_grad()
|
||||
def init_from_ckpt(self, path, ignore_keys=tuple(), only_model=False):
|
||||
def init_from_ckpt(self, path, ignore_keys=(), only_model=False):
|
||||
sd = torch.load(path, map_location="cpu")
|
||||
if "state_dict" in list(sd.keys()):
|
||||
sd = sd["state_dict"]
|
||||
@ -664,7 +664,7 @@ def _TileModeConv2DConvForward(
|
||||
|
||||
|
||||
class LatentDiffusion(DDPM):
|
||||
"""main class"""
|
||||
"""main class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -728,7 +728,7 @@ class LatentDiffusion(DDPM):
|
||||
)
|
||||
|
||||
def tile_mode(self, tile_mode):
|
||||
"""For creating seamless tiles"""
|
||||
"""For creating seamless tiles."""
|
||||
tile_mode = tile_mode or ""
|
||||
tile_x = "x" in tile_mode
|
||||
tile_y = "y" in tile_mode
|
||||
@ -904,9 +904,12 @@ class LatentDiffusion(DDPM):
|
||||
Lx = (w - kernel_size[1]) // stride[1] + 1
|
||||
|
||||
if uf == 1 and df == 1:
|
||||
fold_params = dict(
|
||||
kernel_size=kernel_size, dilation=1, padding=0, stride=stride
|
||||
)
|
||||
fold_params = {
|
||||
"kernel_size": kernel_size,
|
||||
"dilation": 1,
|
||||
"padding": 0,
|
||||
"stride": stride,
|
||||
}
|
||||
unfold = torch.nn.Unfold(**fold_params)
|
||||
|
||||
fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)
|
||||
@ -918,17 +921,20 @@ class LatentDiffusion(DDPM):
|
||||
weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))
|
||||
|
||||
elif uf > 1 and df == 1:
|
||||
fold_params = dict(
|
||||
kernel_size=kernel_size, dilation=1, padding=0, stride=stride
|
||||
)
|
||||
fold_params = {
|
||||
"kernel_size": kernel_size,
|
||||
"dilation": 1,
|
||||
"padding": 0,
|
||||
"stride": stride,
|
||||
}
|
||||
unfold = torch.nn.Unfold(**fold_params)
|
||||
|
||||
fold_params2 = dict(
|
||||
kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf),
|
||||
dilation=1,
|
||||
padding=0,
|
||||
stride=(stride[0] * uf, stride[1] * uf),
|
||||
)
|
||||
fold_params2 = {
|
||||
"kernel_size": (kernel_size[0] * uf, kernel_size[0] * uf),
|
||||
"dilation": 1,
|
||||
"padding": 0,
|
||||
"stride": (stride[0] * uf, stride[1] * uf),
|
||||
}
|
||||
fold = torch.nn.Fold(
|
||||
output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2
|
||||
)
|
||||
@ -944,17 +950,20 @@ class LatentDiffusion(DDPM):
|
||||
)
|
||||
|
||||
elif df > 1 and uf == 1:
|
||||
fold_params = dict(
|
||||
kernel_size=kernel_size, dilation=1, padding=0, stride=stride
|
||||
)
|
||||
fold_params = {
|
||||
"kernel_size": kernel_size,
|
||||
"dilation": 1,
|
||||
"padding": 0,
|
||||
"stride": stride,
|
||||
}
|
||||
unfold = torch.nn.Unfold(**fold_params)
|
||||
|
||||
fold_params2 = dict(
|
||||
kernel_size=(kernel_size[0] // df, kernel_size[0] // df),
|
||||
dilation=1,
|
||||
padding=0,
|
||||
stride=(stride[0] // df, stride[1] // df),
|
||||
)
|
||||
fold_params2 = {
|
||||
"kernel_size": (kernel_size[0] // df, kernel_size[0] // df),
|
||||
"dilation": 1,
|
||||
"padding": 0,
|
||||
"stride": (stride[0] // df, stride[1] // df),
|
||||
}
|
||||
fold = torch.nn.Fold(
|
||||
output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2
|
||||
)
|
||||
@ -1370,7 +1379,7 @@ class DiffusionWrapper(pl.LightningModule):
|
||||
class LatentFinetuneDiffusion(LatentDiffusion):
|
||||
"""
|
||||
Basis for different finetunas, such as inpainting or depth2image
|
||||
To disable finetuning mode, set finetune_keys to None
|
||||
To disable finetuning mode, set finetune_keys to None.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -1399,7 +1408,7 @@ class LatentFinetuneDiffusion(LatentDiffusion):
|
||||
if ckpt_path is not None:
|
||||
self.init_from_ckpt(ckpt_path, ignore_keys)
|
||||
|
||||
def init_from_ckpt(self, path, ignore_keys=tuple(), only_model=False):
|
||||
def init_from_ckpt(self, path, ignore_keys=(), only_model=False):
|
||||
sd = torch.load(path, map_location="cpu")
|
||||
if "state_dict" in list(sd.keys()):
|
||||
sd = sd["state_dict"]
|
||||
@ -1606,7 +1615,7 @@ class LatentInpaintDiffusion(LatentDiffusion):
|
||||
|
||||
class LatentDepth2ImageDiffusion(LatentFinetuneDiffusion):
|
||||
"""
|
||||
condition on monocular depth estimation
|
||||
condition on monocular depth estimation.
|
||||
"""
|
||||
|
||||
def __init__(self, depth_stage_config, concat_keys=("midas_in",), **kwargs):
|
||||
@ -1671,7 +1680,7 @@ class LatentDepth2ImageDiffusion(LatentFinetuneDiffusion):
|
||||
|
||||
class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion):
|
||||
"""
|
||||
condition on low-res image (and optionally on some spatial noise augmentation)
|
||||
condition on low-res image (and optionally on some spatial noise augmentation).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
@ -288,7 +288,7 @@ class MemoryEfficientAttnBlock(nn.Module):
|
||||
"""
|
||||
Uses xformers efficient implementation,
|
||||
see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
|
||||
Note: this is a single-head self-attention operation
|
||||
Note: this is a single-head self-attention operation.
|
||||
"""
|
||||
|
||||
#
|
||||
@ -320,16 +320,16 @@ class MemoryEfficientAttnBlock(nn.Module):
|
||||
|
||||
# compute attention
|
||||
B, C, H, W = q.shape
|
||||
q, k, v = map(lambda x: rearrange(x, "b c h w -> b (h w) c"), (q, k, v))
|
||||
|
||||
q, k, v = map(
|
||||
lambda t: t.unsqueeze(3)
|
||||
q, k, v = (rearrange(x, "b c h w -> b (h w) c") for x in (q, k, v))
|
||||
q, k, v = (
|
||||
t.unsqueeze(3)
|
||||
.reshape(B, t.shape[1], 1, C)
|
||||
.permute(0, 2, 1, 3)
|
||||
.reshape(B * 1, t.shape[1], C)
|
||||
.contiguous(),
|
||||
(q, k, v),
|
||||
.contiguous()
|
||||
for t in (q, k, v)
|
||||
)
|
||||
|
||||
out = xformers.ops.memory_efficient_attention(
|
||||
q, k, v, attn_bias=None, op=self.attention_op
|
||||
)
|
||||
@ -704,8 +704,7 @@ class Decoder(nn.Module):
|
||||
self.give_pre_end = give_pre_end
|
||||
self.tanh_out = tanh_out
|
||||
|
||||
# compute in_ch_mult, block_in and curr_res at lowest res
|
||||
in_ch_mult = (1,) + tuple(ch_mult)
|
||||
# compute block_in and curr_res at lowest res
|
||||
block_in = ch * ch_mult[self.num_resolutions - 1]
|
||||
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
||||
self.z_shape = (1, z_channels, curr_res, curr_res)
|
||||
|
@ -30,7 +30,7 @@ def convert_module_to_f32(_):
|
||||
|
||||
class AttentionPool2d(nn.Module):
|
||||
"""
|
||||
Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
|
||||
Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -123,7 +123,7 @@ class Upsample(nn.Module):
|
||||
|
||||
|
||||
class TransposedUpsample(nn.Module):
|
||||
"""Learned 2x upsampling without padding"""
|
||||
"""Learned 2x upsampling without padding."""
|
||||
|
||||
def __init__(self, channels, out_channels=None, ks=5):
|
||||
super().__init__()
|
||||
@ -346,7 +346,7 @@ def count_flops_attn(model, _x, y):
|
||||
model,
|
||||
inputs=(inputs, timestamps),
|
||||
custom_ops={QKVAttention: QKVAttention.count_flops},
|
||||
)
|
||||
).
|
||||
"""
|
||||
b, c, *spatial = y[0].shape
|
||||
num_spatial = int(np.prod(spatial))
|
||||
@ -359,7 +359,7 @@ def count_flops_attn(model, _x, y):
|
||||
|
||||
class QKVAttentionLegacy(nn.Module):
|
||||
"""
|
||||
A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
|
||||
A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping.
|
||||
"""
|
||||
|
||||
def __init__(self, n_heads):
|
||||
@ -530,11 +530,10 @@ class UNetModel(nn.Module):
|
||||
if num_attention_blocks is not None:
|
||||
assert len(num_attention_blocks) == len(self.num_res_blocks)
|
||||
assert all(
|
||||
map(
|
||||
lambda i: self.num_res_blocks[i] >= num_attention_blocks[i],
|
||||
range(len(num_attention_blocks)),
|
||||
)
|
||||
self.num_res_blocks[i] >= num_attention_blocks[i]
|
||||
for i in range(len(num_attention_blocks))
|
||||
)
|
||||
|
||||
print(
|
||||
f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
|
||||
f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
|
||||
|
@ -57,7 +57,7 @@ def make_beta_schedule(
|
||||
|
||||
|
||||
def frange(start, stop, step):
|
||||
"""Range but handles floats"""
|
||||
"""Range but handles floats."""
|
||||
x = start
|
||||
while True:
|
||||
if x >= stop:
|
||||
|
@ -54,7 +54,7 @@ def disabled_train(self, mode=True): # noqa
|
||||
|
||||
|
||||
class FrozenT5Embedder(AbstractEncoder):
|
||||
"""Uses the T5 transformer encoder for text"""
|
||||
"""Uses the T5 transformer encoder for text."""
|
||||
|
||||
def __init__(
|
||||
self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True
|
||||
@ -94,7 +94,7 @@ class FrozenT5Embedder(AbstractEncoder):
|
||||
|
||||
|
||||
class FrozenCLIPEmbedder(AbstractEncoder):
|
||||
"""Uses the CLIP transformer encoder for text (from huggingface)"""
|
||||
"""Uses the CLIP transformer encoder for text (from huggingface)."""
|
||||
|
||||
LAYERS = ["last", "pooled", "hidden"]
|
||||
|
||||
@ -155,7 +155,7 @@ class FrozenCLIPEmbedder(AbstractEncoder):
|
||||
|
||||
class FrozenOpenCLIPEmbedder(AbstractEncoder):
|
||||
"""
|
||||
Uses the OpenCLIP transformer encoder for text
|
||||
Uses the OpenCLIP transformer encoder for text.
|
||||
"""
|
||||
|
||||
LAYERS = [
|
||||
|
@ -1,7 +1,7 @@
|
||||
"""
|
||||
MidashNet: Network for monocular depth estimation trained by mixing several datasets.
|
||||
This file contains code that is adapted from
|
||||
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
|
||||
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py.
|
||||
"""
|
||||
import torch
|
||||
from torch import nn
|
||||
|
@ -1,7 +1,7 @@
|
||||
"""
|
||||
MidashNet: Network for monocular depth estimation trained by mixing several datasets.
|
||||
This file contains code that is adapted from
|
||||
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
|
||||
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py.
|
||||
"""
|
||||
import torch
|
||||
from torch import nn
|
||||
|
@ -57,7 +57,7 @@ class Transpose(nn.Module):
|
||||
def forward_vit(pretrained, x):
|
||||
b, c, h, w = x.shape
|
||||
|
||||
glob = pretrained.model.forward_flex(x)
|
||||
pretrained.model.forward_flex(x)
|
||||
|
||||
layer_1 = pretrained.activations["1"]
|
||||
layer_2 = pretrained.activations["2"]
|
||||
|
@ -105,15 +105,15 @@ def write_pfm(path, image, scale=1):
|
||||
else:
|
||||
raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.")
|
||||
|
||||
file.write("PF\n" if color else "Pf\n".encode())
|
||||
file.write("%d %d\n".encode() % (image.shape[1], image.shape[0]))
|
||||
file.write("PF\n" if color else b"Pf\n")
|
||||
file.write(b"%d %d\n" % (image.shape[1], image.shape[0]))
|
||||
|
||||
endian = image.dtype.byteorder
|
||||
|
||||
if endian == "<" or endian == "=" and sys.byteorder == "little":
|
||||
scale = -scale
|
||||
|
||||
file.write("%f\n".encode() % scale)
|
||||
file.write(b"%f\n" % scale)
|
||||
|
||||
image.tofile(file)
|
||||
|
||||
|
@ -115,7 +115,7 @@ class EnhancedStableDiffusionSafetyChecker(
|
||||
return safety_results
|
||||
|
||||
|
||||
@lru_cache()
|
||||
@lru_cache
|
||||
def safety_models():
|
||||
safety_model_id = "CompVis/stable-diffusion-safety-checker"
|
||||
monkeypatch_safety_cosine_distance()
|
||||
@ -126,14 +126,14 @@ def safety_models():
|
||||
return safety_feature_extractor, safety_checker
|
||||
|
||||
|
||||
@lru_cache()
|
||||
@lru_cache
|
||||
def monkeypatch_safety_cosine_distance():
|
||||
orig_cosine_distance = safety_checker_mod.cosine_distance
|
||||
|
||||
def cosine_distance_float32(image_embeds, text_embeds):
|
||||
"""
|
||||
In some environments we need to distance to be in float32
|
||||
but it was coming as BFloat16
|
||||
but it was coming as BFloat16.
|
||||
"""
|
||||
return orig_cosine_distance(image_embeds, text_embeds).to(torch.float32)
|
||||
|
||||
|
@ -21,7 +21,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class DDIMSampler(ImageSampler):
|
||||
"""
|
||||
Denoising Diffusion Implicit Models
|
||||
Denoising Diffusion Implicit Models.
|
||||
|
||||
https://arxiv.org/abs/2010.02502
|
||||
"""
|
||||
|
@ -208,7 +208,7 @@ class LMSSampler(KDiffusionSampler):
|
||||
|
||||
class CFGDenoiser(nn.Module):
|
||||
"""
|
||||
Conditional forward guidance wrapper
|
||||
Conditional forward guidance wrapper.
|
||||
"""
|
||||
|
||||
def __init__(self, model):
|
||||
|
@ -21,7 +21,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class PLMSSampler(ImageSampler):
|
||||
"""
|
||||
probabilistic least-mean-squares
|
||||
probabilistic least-mean-squares.
|
||||
|
||||
Provenance:
|
||||
https://github.com/CompVis/latent-diffusion/commit/f0c4e092c156986e125f48c61a0edd38ba8ad059
|
||||
|
@ -216,7 +216,7 @@ class ImaginePrompt:
|
||||
|
||||
|
||||
class ExifCodes:
|
||||
"""https://www.awaresystems.be/imaging/tiff/tifftags/baseline.html"""
|
||||
"""https://www.awaresystems.be/imaging/tiff/tifftags/baseline.html."""
|
||||
|
||||
ImageDescription = 0x010E
|
||||
Software = 0x0131
|
||||
|
@ -13,9 +13,9 @@ from torch.overrides import handle_torch_function, has_torch_function_variadic
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@lru_cache()
|
||||
@lru_cache
|
||||
def get_device() -> str:
|
||||
"""Return the best torch backend available"""
|
||||
"""Return the best torch backend available."""
|
||||
if torch.cuda.is_available():
|
||||
return "cuda"
|
||||
|
||||
@ -25,9 +25,9 @@ def get_device() -> str:
|
||||
return "cpu"
|
||||
|
||||
|
||||
@lru_cache()
|
||||
@lru_cache
|
||||
def get_hardware_description(device_type: str) -> str:
|
||||
"""Description of the hardware being used"""
|
||||
"""Description of the hardware being used."""
|
||||
desc = platform.platform()
|
||||
if device_type == "cuda":
|
||||
desc += "-" + torch.cuda.get_device_name(0)
|
||||
@ -37,7 +37,7 @@ def get_hardware_description(device_type: str) -> str:
|
||||
|
||||
def get_obj_from_str(import_path: str, reload=False) -> Any:
|
||||
"""
|
||||
Gets a python object from a string reference if it's location
|
||||
Gets a python object from a string reference if it's location.
|
||||
|
||||
Example: "functools.lru_cache"
|
||||
"""
|
||||
@ -50,7 +50,7 @@ def get_obj_from_str(import_path: str, reload=False) -> Any:
|
||||
|
||||
|
||||
def instantiate_from_config(config: Union[dict, str]) -> Any:
|
||||
"""Instantiate an object from a config dict"""
|
||||
"""Instantiate an object from a config dict."""
|
||||
if "target" not in config:
|
||||
if config == "__is_first_stage__":
|
||||
return None
|
||||
@ -65,7 +65,7 @@ def instantiate_from_config(config: Union[dict, str]) -> Any:
|
||||
@contextmanager
|
||||
def platform_appropriate_autocast(precision="autocast"):
|
||||
"""
|
||||
Allow calculations to run in mixed precision, which can be faster
|
||||
Allow calculations to run in mixed precision, which can be faster.
|
||||
"""
|
||||
precision_scope = nullcontext
|
||||
# autocast not supported on CPU
|
||||
@ -111,7 +111,7 @@ def _fixed_layer_norm(
|
||||
|
||||
@contextmanager
|
||||
def fix_torch_nn_layer_norm():
|
||||
"""https://github.com/CompVis/stable-diffusion/issues/25#issuecomment-1221416526"""
|
||||
"""https://github.com/CompVis/stable-diffusion/issues/25#issuecomment-1221416526."""
|
||||
orig_function = functional.layer_norm
|
||||
functional.layer_norm = _fixed_layer_norm
|
||||
try:
|
||||
@ -123,7 +123,7 @@ def fix_torch_nn_layer_norm():
|
||||
@contextmanager
|
||||
def fix_torch_group_norm():
|
||||
"""
|
||||
Patch group_norm to cast the weights to the same type as the inputs
|
||||
Patch group_norm to cast the weights to the same type as the inputs.
|
||||
|
||||
From what I can understand all the other repos just switch to full precision instead
|
||||
of addressing this. I think this would make things slower but I'm not sure.
|
||||
@ -158,7 +158,7 @@ def fix_torch_group_norm():
|
||||
|
||||
|
||||
def randn_seeded(seed: int, size: List[int]) -> Tensor:
|
||||
"""Generate a random tensor with a given seed"""
|
||||
"""Generate a random tensor with a given seed."""
|
||||
g_cpu = torch.Generator()
|
||||
g_cpu.manual_seed(seed)
|
||||
noise = torch.randn(
|
||||
@ -170,7 +170,7 @@ def randn_seeded(seed: int, size: List[int]) -> Tensor:
|
||||
|
||||
|
||||
def check_torch_working():
|
||||
"""Check that torch is working"""
|
||||
"""Check that torch is working."""
|
||||
try:
|
||||
torch.randn(1, device=get_device())
|
||||
except RuntimeError as e:
|
||||
|
@ -3,7 +3,7 @@
|
||||
* All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
||||
* By Junnan Li
|
||||
* By Junnan Li.
|
||||
"""
|
||||
import warnings
|
||||
|
||||
@ -34,7 +34,7 @@ class BLIP_Base(nn.Module):
|
||||
Args:
|
||||
med_config (str): path for the mixture of encoder-decoder model's configuration file
|
||||
image_size (int): input image size
|
||||
vit (str): model size of vision transformer
|
||||
vit (str): model size of vision transformer.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
@ -102,7 +102,7 @@ class BLIP_Decoder(nn.Module):
|
||||
Args:
|
||||
med_config (str): path for the mixture of encoder-decoder model's configuration file
|
||||
image_size (int): input image size
|
||||
vit (str): model size of vision transformer
|
||||
vit (str): model size of vision transformer.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
|
@ -19,7 +19,7 @@ class BLIP_ITM(nn.Module):
|
||||
Args:
|
||||
med_config (str): path for the mixture of encoder-decoder model's configuration file
|
||||
image_size (int): input image size
|
||||
vit (str): model size of vision transformer
|
||||
vit (str): model size of vision transformer.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
|
@ -21,7 +21,7 @@ class BLIP_NLVR(nn.Module):
|
||||
Args:
|
||||
med_config (str): path for the mixture of encoder-decoder model's configuration file
|
||||
image_size (int): input image size
|
||||
vit (str): model size of vision transformer
|
||||
vit (str): model size of vision transformer.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
|
@ -3,8 +3,10 @@
|
||||
* All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
||||
* By Junnan Li
|
||||
* By Junnan Li.
|
||||
"""
|
||||
from typing import List
|
||||
|
||||
import transformers
|
||||
from models.med import BertConfig, BertLMHeadModel, BertModel
|
||||
|
||||
@ -32,7 +34,7 @@ class BLIP_Pretrain(nn.Module):
|
||||
Args:
|
||||
med_config (str): path for the mixture of encoder-decoder model's configuration file
|
||||
image_size (int): input image size
|
||||
vit (str): model size of vision transformer
|
||||
vit (str): model size of vision transformer.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
@ -329,9 +331,6 @@ def concat_all_gather(tensor):
|
||||
return output
|
||||
|
||||
|
||||
from typing import List
|
||||
|
||||
|
||||
def tie_encoder_decoder_weights(
|
||||
encoder: nn.Module, decoder: nn.Module, base_model_prefix: str, skip_key: str
|
||||
):
|
||||
@ -368,9 +367,9 @@ def tie_encoder_decoder_weights(
|
||||
len(encoder_modules) > 0
|
||||
), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}"
|
||||
|
||||
all_encoder_weights = set(
|
||||
[module_name + "/" + sub_name for sub_name in encoder_modules.keys()]
|
||||
)
|
||||
all_encoder_weights = {
|
||||
module_name + "/" + sub_name for sub_name in encoder_modules.keys()
|
||||
}
|
||||
encoder_layer_pos = 0
|
||||
for name, module in decoder_modules.items():
|
||||
if name.isdigit():
|
||||
|
@ -22,7 +22,7 @@ class BLIP_Retrieval(nn.Module):
|
||||
Args:
|
||||
med_config (str): path for the mixture of encoder-decoder model's configuration file
|
||||
image_size (int): input image size
|
||||
vit (str): model size of vision transformer
|
||||
vit (str): model size of vision transformer.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
|
@ -19,7 +19,7 @@ class BLIP_VQA(nn.Module):
|
||||
Args:
|
||||
med_config (str): path for the mixture of encoder-decoder model's configuration file
|
||||
image_size (int): input image size
|
||||
vit (str): model size of vision transformer
|
||||
vit (str): model size of vision transformer.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
|
@ -5,7 +5,7 @@
|
||||
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
||||
* By Junnan Li
|
||||
* Based on huggingface code base
|
||||
* https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
|
||||
* https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert.
|
||||
"""
|
||||
|
||||
import math
|
||||
@ -611,7 +611,7 @@ class BertPreTrainedModel(PreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
"""Initialize the weights."""
|
||||
if isinstance(module, (nn.Linear, nn.Embedding)):
|
||||
# Slightly different from the TF version which uses truncated_normal for initialization
|
||||
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||
@ -654,7 +654,7 @@ class BertModel(BertPreTrainedModel):
|
||||
def _prune_heads(self, heads_to_prune):
|
||||
"""
|
||||
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
|
||||
class PreTrainedModel
|
||||
class PreTrainedModel.
|
||||
"""
|
||||
for layer, heads in heads_to_prune.items():
|
||||
self.encoder.layer[layer].attention.prune_heads(heads)
|
||||
@ -977,7 +977,7 @@ class BertLMHeadModel(BertPreTrainedModel):
|
||||
>>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
|
||||
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
|
||||
>>> outputs = model(**inputs)
|
||||
>>> prediction_logits = outputs.logits
|
||||
>>> prediction_logits = outputs.logits.
|
||||
"""
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
@ -656,7 +656,7 @@ class BertPreTrainedModel(PreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
"""Initialize the weights."""
|
||||
if isinstance(module, (nn.Linear, nn.Embedding)):
|
||||
# Slightly different from the TF version which uses truncated_normal for initialization
|
||||
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||
@ -699,7 +699,7 @@ class BertModel(BertPreTrainedModel):
|
||||
def _prune_heads(self, heads_to_prune):
|
||||
"""
|
||||
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
|
||||
class PreTrainedModel
|
||||
class PreTrainedModel.
|
||||
"""
|
||||
for layer, heads in heads_to_prune.items():
|
||||
self.encoder.layer[layer].attention.prune_heads(heads)
|
||||
|
@ -5,7 +5,7 @@
|
||||
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
||||
* By Junnan Li
|
||||
* Based on timm code base
|
||||
* https://github.com/rwightman/pytorch-image-models/tree/master/timm
|
||||
* https://github.com/rwightman/pytorch-image-models/tree/master/timm.
|
||||
"""
|
||||
|
||||
from functools import partial
|
||||
@ -19,7 +19,7 @@ from timm.models.vision_transformer import PatchEmbed
|
||||
|
||||
|
||||
class Mlp(nn.Module):
|
||||
"""MLP as used in Vision Transformer, MLP-Mixer and related networks"""
|
||||
"""MLP as used in Vision Transformer, MLP-Mixer and related networks."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -156,7 +156,7 @@ class Block(nn.Module):
|
||||
class VisionTransformer(nn.Module):
|
||||
"""Vision Transformer
|
||||
A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
|
||||
https://arxiv.org/abs/2010.11929
|
||||
https://arxiv.org/abs/2010.11929.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -195,7 +195,7 @@ class VisionTransformer(nn.Module):
|
||||
drop_rate (float): dropout rate
|
||||
attn_drop_rate (float): attention dropout rate
|
||||
drop_path_rate (float): stochastic depth rate
|
||||
norm_layer: (nn.Module): normalization layer
|
||||
norm_layer: (nn.Module): normalization layer.
|
||||
"""
|
||||
super().__init__()
|
||||
self.num_features = (
|
||||
@ -282,7 +282,7 @@ class VisionTransformer(nn.Module):
|
||||
|
||||
@torch.no_grad()
|
||||
def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ""):
|
||||
"""Load weights from .npz checkpoints for official Google Brain Flax implementation"""
|
||||
"""Load weights from .npz checkpoints for official Google Brain Flax implementation."""
|
||||
import numpy as np
|
||||
|
||||
def _n2p(w, t=True):
|
||||
|
@ -109,7 +109,7 @@ def _transform(n_px):
|
||||
|
||||
|
||||
def available_models() -> List[str]:
|
||||
"""Returns the names of available CLIP models"""
|
||||
"""Returns the names of available CLIP models."""
|
||||
return list(_MODELS.keys())
|
||||
|
||||
|
||||
@ -119,7 +119,7 @@ def load(
|
||||
jit: bool = False,
|
||||
download_root: str = None,
|
||||
):
|
||||
"""Load a CLIP model
|
||||
"""Load a CLIP model.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@ -246,7 +246,7 @@ def tokenize(
|
||||
texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False
|
||||
) -> Union[torch.IntTensor, torch.LongTensor]:
|
||||
"""
|
||||
Returns the tokenized representation of given input string(s)
|
||||
Returns the tokenized representation of given input string(s).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
@ -117,7 +117,7 @@ class ModifiedResNet(nn.Module):
|
||||
A ResNet class that is similar to torchvision's but contains the following changes:
|
||||
- There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
|
||||
- Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
|
||||
- The final pooling layer is a QKV attention instead of an average pool
|
||||
- The final pooling layer is a QKV attention instead of an average pool.
|
||||
"""
|
||||
|
||||
def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
|
||||
@ -447,7 +447,7 @@ class CLIP(nn.Module):
|
||||
|
||||
|
||||
def convert_weights(model: nn.Module):
|
||||
"""Convert applicable model parameters to fp16"""
|
||||
"""Convert applicable model parameters to fp16."""
|
||||
|
||||
def _convert_weights_to_fp16(l):
|
||||
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
|
||||
@ -495,11 +495,11 @@ def build_model(state_dict: dict):
|
||||
else:
|
||||
counts: list = [
|
||||
len(
|
||||
set(
|
||||
{
|
||||
k.split(".")[2]
|
||||
for k in state_dict
|
||||
if k.startswith(f"visual.layer{b}")
|
||||
)
|
||||
}
|
||||
)
|
||||
for b in [1, 2, 3, 4]
|
||||
]
|
||||
@ -521,9 +521,7 @@ def build_model(state_dict: dict):
|
||||
transformer_width = state_dict["ln_final.weight"].shape[0]
|
||||
transformer_heads = transformer_width // 64
|
||||
transformer_layers = len(
|
||||
set(
|
||||
k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")
|
||||
)
|
||||
{k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")}
|
||||
)
|
||||
|
||||
model = CLIP(
|
||||
|
@ -7,14 +7,14 @@ import ftfy
|
||||
import regex as re
|
||||
|
||||
|
||||
@lru_cache()
|
||||
@lru_cache
|
||||
def default_bpe():
|
||||
return os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz"
|
||||
)
|
||||
|
||||
|
||||
@lru_cache()
|
||||
@lru_cache
|
||||
def bytes_to_unicode():
|
||||
"""
|
||||
Returns list of utf-8 byte and a corresponding list of unicode strings.
|
||||
@ -65,7 +65,7 @@ def whitespace_clean(text):
|
||||
return text
|
||||
|
||||
|
||||
class SimpleTokenizer(object):
|
||||
class SimpleTokenizer:
|
||||
def __init__(self, bpe_path: str = default_bpe()):
|
||||
self.byte_encoder = bytes_to_unicode()
|
||||
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
|
||||
|
@ -4,7 +4,6 @@ from os.path import basename, dirname, isfile, join
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as nnf
|
||||
from torch.nn.modules.activation import ReLU
|
||||
|
||||
|
||||
def precompute_clip_vectors():
|
||||
@ -182,7 +181,7 @@ class CLIPDenseBase(nn.Module):
|
||||
k: torch.from_numpy(v) for k, v in precomp.items()
|
||||
}
|
||||
else:
|
||||
self.precomputed_prompts = dict()
|
||||
self.precomputed_prompts = {}
|
||||
|
||||
def rescaled_pos_emb(self, new_size):
|
||||
assert len(new_size) == 2
|
||||
@ -383,11 +382,7 @@ def clip_load_untrained(version):
|
||||
transformer_width = state_dict["ln_final.weight"].shape[0]
|
||||
transformer_heads = transformer_width // 64
|
||||
transformer_layers = len(
|
||||
set(
|
||||
k.split(".")[2]
|
||||
for k in state_dict
|
||||
if k.startswith(f"transformer.resblocks")
|
||||
)
|
||||
{k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")}
|
||||
)
|
||||
|
||||
return CLIP(
|
||||
@ -717,12 +712,11 @@ class CLIPSegMultiLabel(nn.Module):
|
||||
def __init__(self, model) -> None:
|
||||
super().__init__()
|
||||
|
||||
from third_party.JoEm.data_loader import VOC, get_seen_idx, get_unseen_idx
|
||||
from third_party.JoEm.data_loader import VOC
|
||||
|
||||
self.pascal_classes = VOC
|
||||
|
||||
from general_utils import load_model
|
||||
from models.clipseg import CLIPDensePredT
|
||||
|
||||
# self.clipseg = load_model('rd64-vit16-neg0.2-phrasecut', strict=False)
|
||||
self.clipseg = load_model(model, strict=False)
|
||||
|
@ -95,7 +95,7 @@ class PositionEmbeddingSine(nn.Module):
|
||||
|
||||
|
||||
def _get_activation_fn(activation):
|
||||
"""Return an activation function given a string"""
|
||||
"""Return an activation function given a string."""
|
||||
if activation == "relu":
|
||||
return F.relu
|
||||
if activation == "gelu":
|
||||
@ -186,9 +186,7 @@ class CodeFormer(VQAutoEncoder):
|
||||
connect_list=["32", "64", "128", "256"],
|
||||
fix_modules=["quantize", "generator"],
|
||||
):
|
||||
super(CodeFormer, self).__init__(
|
||||
512, 64, [1, 2, 2, 4, 4, 8], "nearest", 2, [16], codebook_size
|
||||
)
|
||||
super().__init__(512, 64, [1, 2, 2, 4, 4, 8], "nearest", 2, [16], codebook_size)
|
||||
|
||||
if fix_modules is not None:
|
||||
for module in fix_modules:
|
||||
|
@ -1,6 +1,6 @@
|
||||
"""
|
||||
VQGAN code, adapted from the original created by the Unleashing Transformers authors:
|
||||
https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py
|
||||
https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py.
|
||||
|
||||
"""
|
||||
|
||||
@ -25,7 +25,7 @@ def swish(x):
|
||||
# Define VQVAE classes
|
||||
class VectorQuantizer(nn.Module):
|
||||
def __init__(self, codebook_size, emb_dim, beta):
|
||||
super(VectorQuantizer, self).__init__()
|
||||
super().__init__()
|
||||
self.codebook_size = codebook_size # number of embeddings
|
||||
self.emb_dim = emb_dim # dimension of embedding
|
||||
self.beta = beta # commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
|
||||
@ -173,7 +173,7 @@ class Upsample(nn.Module):
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
def __init__(self, in_channels, out_channels=None):
|
||||
super(ResBlock, self).__init__()
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = in_channels if out_channels is None else out_channels
|
||||
self.norm1 = normalize(in_channels)
|
||||
|
@ -1,12 +0,0 @@
|
||||
from . import (
|
||||
augmentation,
|
||||
config,
|
||||
evaluation,
|
||||
external,
|
||||
gns,
|
||||
layers,
|
||||
models,
|
||||
sampling,
|
||||
utils,
|
||||
)
|
||||
from .layers import Denoiser
|
@ -1 +0,0 @@
|
||||
from .image_v1 import ImageDenoiserModelV1
|
@ -56,18 +56,20 @@ class DBlock(layers.ConditionedSequential):
|
||||
)
|
||||
)
|
||||
if self_attn:
|
||||
norm = lambda c_in: layers.AdaGN(
|
||||
feats_in, c_in, max(1, my_c_out // group_size)
|
||||
)
|
||||
|
||||
def norm(c_in):
|
||||
return layers.AdaGN(feats_in, c_in, max(1, my_c_out // group_size))
|
||||
|
||||
modules.append(
|
||||
layers.SelfAttention2d(
|
||||
my_c_out, max(1, my_c_out // head_size), norm, dropout_rate
|
||||
)
|
||||
)
|
||||
if cross_attn:
|
||||
norm = lambda c_in: layers.AdaGN(
|
||||
feats_in, c_in, max(1, my_c_out // group_size)
|
||||
)
|
||||
|
||||
def norm(c_in):
|
||||
return layers.AdaGN(feats_in, c_in, max(1, my_c_out // group_size))
|
||||
|
||||
modules.append(
|
||||
layers.CrossAttention2d(
|
||||
my_c_out,
|
||||
@ -111,18 +113,20 @@ class UBlock(layers.ConditionedSequential):
|
||||
)
|
||||
)
|
||||
if self_attn:
|
||||
norm = lambda c_in: layers.AdaGN(
|
||||
feats_in, c_in, max(1, my_c_out // group_size)
|
||||
)
|
||||
|
||||
def norm(c_in):
|
||||
return layers.AdaGN(feats_in, c_in, max(1, my_c_out // group_size))
|
||||
|
||||
modules.append(
|
||||
layers.SelfAttention2d(
|
||||
my_c_out, max(1, my_c_out // head_size), norm, dropout_rate
|
||||
)
|
||||
)
|
||||
if cross_attn:
|
||||
norm = lambda c_in: layers.AdaGN(
|
||||
feats_in, c_in, max(1, my_c_out // group_size)
|
||||
)
|
||||
|
||||
def norm(c_in):
|
||||
return layers.AdaGN(feats_in, c_in, max(1, my_c_out // group_size))
|
||||
|
||||
modules.append(
|
||||
layers.CrossAttention2d(
|
||||
my_c_out,
|
||||
|
@ -798,8 +798,12 @@ def sample_dpmpp_2s_ancestral(
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
sigma_fn = lambda t: t.neg().exp()
|
||||
t_fn = lambda sigma: sigma.to("cpu").log().neg().to(x.device)
|
||||
|
||||
def sigma_fn(t):
|
||||
return t.neg().exp()
|
||||
|
||||
def t_fn(sigma):
|
||||
return sigma.to("cpu").log().neg().to(x.device)
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
@ -856,8 +860,12 @@ def sample_dpmpp_sde(
|
||||
)
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
sigma_fn = lambda t: t.neg().exp()
|
||||
t_fn = lambda sigma: sigma.to("cpu").log().neg().to(x.device)
|
||||
|
||||
def sigma_fn(t):
|
||||
return t.neg().exp()
|
||||
|
||||
def t_fn(sigma):
|
||||
return sigma.to("cpu").log().neg().to(x.device)
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
@ -906,8 +914,13 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No
|
||||
"""DPM-Solver++(2M)."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
sigma_fn = lambda t: t.neg().exp()
|
||||
t_fn = lambda sigma: sigma.to("cpu").log().neg().to(x.device)
|
||||
|
||||
def sigma_fn(t):
|
||||
return t.neg().exp()
|
||||
|
||||
def t_fn(sigma):
|
||||
return sigma.to("cpu").log().neg().to(x.device)
|
||||
|
||||
old_denoised = None
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
|
@@ -1,7 +1,7 @@
black
coverage
isort
pycln
ruff
pydocstyle
pylama
pylint

@ -20,7 +20,7 @@ astroid==2.12.13
|
||||
# via pylint
|
||||
async-timeout==4.0.2
|
||||
# via aiohttp
|
||||
attrs==22.1.0
|
||||
attrs==22.2.0
|
||||
# via
|
||||
# aiohttp
|
||||
# pytest
|
||||
@ -28,11 +28,11 @@ basicsr==1.4.2
|
||||
# via
|
||||
# gfpgan
|
||||
# realesrgan
|
||||
black==22.10.0
|
||||
black==22.12.0
|
||||
# via -r requirements-dev.in
|
||||
cachetools==5.2.0
|
||||
# via google-auth
|
||||
certifi==2022.9.24
|
||||
certifi==2022.12.7
|
||||
# via requests
|
||||
charset-normalizer==2.1.1
|
||||
# via
|
||||
@ -43,30 +43,29 @@ click==8.1.3
|
||||
# black
|
||||
# click-shell
|
||||
# imaginAIry (setup.py)
|
||||
# typer
|
||||
click-shell==2.1
|
||||
# via imaginAIry (setup.py)
|
||||
contourpy==1.0.6
|
||||
# via matplotlib
|
||||
coverage==6.5.0
|
||||
coverage==7.0.1
|
||||
# via -r requirements-dev.in
|
||||
cycler==0.11.0
|
||||
# via matplotlib
|
||||
diffusers==0.8.1
|
||||
diffusers==0.11.1
|
||||
# via imaginAIry (setup.py)
|
||||
dill==0.3.6
|
||||
# via pylint
|
||||
einops==0.3.0
|
||||
# via imaginAIry (setup.py)
|
||||
exceptiongroup==1.0.4
|
||||
exceptiongroup==1.1.0
|
||||
# via pytest
|
||||
facexlib==0.2.5
|
||||
# via
|
||||
# gfpgan
|
||||
# realesrgan
|
||||
fairscale==0.4.12
|
||||
fairscale==0.4.13
|
||||
# via imaginAIry (setup.py)
|
||||
filelock==3.8.0
|
||||
filelock==3.9.0
|
||||
# via
|
||||
# diffusers
|
||||
# huggingface-hub
|
||||
@ -93,7 +92,7 @@ gfpgan==1.3.8
|
||||
# via
|
||||
# imaginAIry (setup.py)
|
||||
# realesrgan
|
||||
google-auth==2.14.1
|
||||
google-auth==2.15.0
|
||||
# via
|
||||
# google-auth-oauthlib
|
||||
# tb-nightly
|
||||
@ -102,11 +101,11 @@ google-auth-oauthlib==0.4.6
|
||||
# via
|
||||
# tb-nightly
|
||||
# tensorboard
|
||||
grpcio==1.50.0
|
||||
grpcio==1.51.1
|
||||
# via
|
||||
# tb-nightly
|
||||
# tensorboard
|
||||
huggingface-hub==0.11.0
|
||||
huggingface-hub==0.11.1
|
||||
# via
|
||||
# diffusers
|
||||
# open-clip-torch
|
||||
@ -120,11 +119,11 @@ imageio==2.9.0
|
||||
# via
|
||||
# imaginAIry (setup.py)
|
||||
# scikit-image
|
||||
importlib-metadata==5.1.0
|
||||
importlib-metadata==6.0.0
|
||||
# via diffusers
|
||||
iniconfig==1.1.1
|
||||
# via pytest
|
||||
isort==5.10.1
|
||||
isort==5.11.4
|
||||
# via
|
||||
# -r requirements-dev.in
|
||||
# pylint
|
||||
@ -134,11 +133,9 @@ kornia==0.6
|
||||
# via imaginAIry (setup.py)
|
||||
lazy-object-proxy==1.8.0
|
||||
# via astroid
|
||||
libcst==0.4.9
|
||||
# via pycln
|
||||
llvmlite==0.39.1
|
||||
# via numba
|
||||
lmdb==1.3.0
|
||||
lmdb==1.4.0
|
||||
# via
|
||||
# basicsr
|
||||
# gfpgan
|
||||
@ -154,14 +151,12 @@ mccabe==0.7.0
|
||||
# via
|
||||
# pylama
|
||||
# pylint
|
||||
multidict==6.0.2
|
||||
multidict==6.0.4
|
||||
# via
|
||||
# aiohttp
|
||||
# yarl
|
||||
mypy-extensions==0.4.3
|
||||
# via
|
||||
# black
|
||||
# typing-inspect
|
||||
# via black
|
||||
networkx==2.8.8
|
||||
# via scikit-image
|
||||
numba==0.56.4
|
||||
@ -195,15 +190,15 @@ oauthlib==3.2.2
|
||||
# via requests-oauthlib
|
||||
omegaconf==2.1.1
|
||||
# via imaginAIry (setup.py)
|
||||
open-clip-torch==2.7.0
|
||||
open-clip-torch==2.9.1
|
||||
# via imaginAIry (setup.py)
|
||||
opencv-python==4.6.0.66
|
||||
opencv-python==4.7.0.68
|
||||
# via
|
||||
# basicsr
|
||||
# facexlib
|
||||
# gfpgan
|
||||
# realesrgan
|
||||
packaging==21.3
|
||||
packaging==22.0
|
||||
# via
|
||||
# huggingface-hub
|
||||
# kornia
|
||||
@ -214,11 +209,9 @@ packaging==21.3
|
||||
# scikit-image
|
||||
# torchmetrics
|
||||
# transformers
|
||||
pathspec==0.9.0
|
||||
# via
|
||||
# black
|
||||
# pycln
|
||||
pillow==9.3.0
|
||||
pathspec==0.10.3
|
||||
# via black
|
||||
pillow==9.4.0
|
||||
# via
|
||||
# basicsr
|
||||
# diffusers
|
||||
@ -229,7 +222,7 @@ pillow==9.3.0
|
||||
# realesrgan
|
||||
# scikit-image
|
||||
# torchvision
|
||||
platformdirs==2.5.4
|
||||
platformdirs==2.6.2
|
||||
# via
|
||||
# black
|
||||
# pylint
|
||||
@ -238,6 +231,7 @@ pluggy==1.0.0
|
||||
protobuf==3.20.3
|
||||
# via
|
||||
# imaginAIry (setup.py)
|
||||
# open-clip-torch
|
||||
# tb-nightly
|
||||
# tensorboard
|
||||
psutil==5.9.4
|
||||
@ -248,8 +242,6 @@ pyasn1==0.4.8
|
||||
# rsa
|
||||
pyasn1-modules==0.2.8
|
||||
# via google-auth
|
||||
pycln==2.1.2
|
||||
# via -r requirements-dev.in
|
||||
pycodestyle==2.10.0
|
||||
# via pylama
|
||||
pydeprecate==0.3.1
|
||||
@ -262,12 +254,10 @@ pyflakes==3.0.1
|
||||
# via pylama
|
||||
pylama==8.4.1
|
||||
# via -r requirements-dev.in
|
||||
pylint==2.15.6
|
||||
pylint==2.15.9
|
||||
# via -r requirements-dev.in
|
||||
pyparsing==3.0.9
|
||||
# via
|
||||
# matplotlib
|
||||
# packaging
|
||||
# via matplotlib
|
||||
pytest==7.2.0
|
||||
# via
|
||||
# -r requirements-dev.in
|
||||
@ -288,9 +278,7 @@ pyyaml==6.0
|
||||
# basicsr
|
||||
# gfpgan
|
||||
# huggingface-hub
|
||||
# libcst
|
||||
# omegaconf
|
||||
# pycln
|
||||
# pytorch-lightning
|
||||
# timm
|
||||
# transformers
|
||||
@ -307,6 +295,7 @@ requests==2.28.1
|
||||
# diffusers
|
||||
# fsspec
|
||||
# huggingface-hub
|
||||
# imaginAIry (setup.py)
|
||||
# requests-oauthlib
|
||||
# responses
|
||||
# tb-nightly
|
||||
@ -319,6 +308,8 @@ responses==0.22.0
|
||||
# via -r requirements-dev.in
|
||||
rsa==4.9
|
||||
# via google-auth
|
||||
ruff==0.0.206
|
||||
# via -r requirements-dev.in
|
||||
scikit-image==0.19.3
|
||||
# via basicsr
|
||||
scipy==1.9.3
|
||||
@ -329,14 +320,15 @@ scipy==1.9.3
|
||||
# gfpgan
|
||||
# scikit-image
|
||||
# torchdiffeq
|
||||
sentencepiece==0.1.97
|
||||
# via open-clip-torch
|
||||
six==1.16.0
|
||||
# via
|
||||
# google-auth
|
||||
# grpcio
|
||||
# python-dateutil
|
||||
snowballstemmer==2.2.0
|
||||
# via pydocstyle
|
||||
tb-nightly==2.12.0a20221125
|
||||
tb-nightly==2.12.0a20230101
|
||||
# via
|
||||
# basicsr
|
||||
# gfpgan
|
||||
@ -366,10 +358,8 @@ tomli==2.0.1
|
||||
# pylint
|
||||
# pytest
|
||||
tomlkit==0.11.6
|
||||
# via
|
||||
# pycln
|
||||
# pylint
|
||||
torch==1.13.0
|
||||
# via pylint
|
||||
torch==1.13.1
|
||||
# via
|
||||
# basicsr
|
||||
# facexlib
|
||||
@ -390,7 +380,7 @@ torchmetrics==0.6.0
|
||||
# via
|
||||
# imaginAIry (setup.py)
|
||||
# pytorch-lightning
|
||||
torchvision==0.14.0
|
||||
torchvision==0.14.1
|
||||
# via
|
||||
# basicsr
|
||||
# facexlib
|
||||
@ -412,20 +402,14 @@ tqdm==4.64.1
|
||||
# transformers
|
||||
transformers==4.19.2
|
||||
# via imaginAIry (setup.py)
|
||||
typer==0.7.0
|
||||
# via pycln
|
||||
types-toml==0.10.8.1
|
||||
# via responses
|
||||
typing-extensions==4.4.0
|
||||
# via
|
||||
# huggingface-hub
|
||||
# libcst
|
||||
# pytorch-lightning
|
||||
# torch
|
||||
# torchvision
|
||||
# typing-inspect
|
||||
typing-inspect==0.8.0
|
||||
# via libcst
|
||||
urllib3==1.26.13
|
||||
# via
|
||||
# requests
|
||||
@ -446,9 +430,9 @@ yapf==0.32.0
|
||||
# via
|
||||
# basicsr
|
||||
# gfpgan
|
||||
yarl==1.8.1
|
||||
yarl==1.8.2
|
||||
# via aiohttp
|
||||
zipp==3.10.0
|
||||
zipp==3.11.0
|
||||
# via importlib-metadata
|
||||
|
||||
# The following packages are considered to be unsafe in a requirements file:
|
||||
|
@ -51,7 +51,7 @@ def timed(description):
|
||||
def make_txts():
|
||||
src_json = f"{CURDIR}/../downloads/noodle-soup-prompts/nsp_pantry.json"
|
||||
dst_folder = f"{CURDIR}/../imaginairy/vendored/noodle_soup_prompts"
|
||||
with open(src_json, "r", encoding="utf-8") as f:
|
||||
with open(src_json, encoding="utf-8") as f:
|
||||
prompts = json.load(f)
|
||||
categories = []
|
||||
for c in prompts.keys():
|
||||
@ -65,7 +65,7 @@ def make_txts():
|
||||
renamed_c = category_renames.get(c, c)
|
||||
with gzip.open(f"{dst_folder}/{renamed_c}.txt.gz", "wb") as f:
|
||||
for p in filtered_phrases:
|
||||
f.write(f"{p}\n".encode("utf-8"))
|
||||
f.write(f"{p}\n".encode())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
setup.py (2 lines changed)
@@ -1,6 +1,6 @@
from setuptools import find_packages, setup

with open("README.md", "r", encoding="utf-8") as f:
with open("README.md", encoding="utf-8") as f:
readme = f.read()

setup(

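The setup.py change drops the explicit "r" argument: text-read is already the default mode for open(), so the two calls behave identically. A standalone sketch (illustration only):

    import pathlib
    import tempfile

    path = pathlib.Path(tempfile.mkdtemp()) / "README.md"
    path.write_text("# demo\n", encoding="utf-8")

    with open(path, "r", encoding="utf-8") as f:  # explicit mode
        explicit = f.read()
    with open(path, encoding="utf-8") as f:       # "r" is the default
        implicit = f.read()
    print(explicit == implicit)  # True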
@ -43,7 +43,7 @@ compare_prompts = [
|
||||
"model_version", ["SD-1.4", "SD-1.5", "SD-2.0", "SD-2.0-v", "SD-2.1", "SD-2.1-v"]
|
||||
)
|
||||
def test_model_versions(filename_base_for_orig_outputs, model_version):
|
||||
"""Test that we can switch between model versions"""
|
||||
"""Test that we can switch between model versions."""
|
||||
prompts = []
|
||||
for prompt_text in compare_prompts:
|
||||
prompts.append(
|
||||
@ -172,19 +172,19 @@ def test_img_to_img_fruit_2_gold_repeat():
|
||||
img = LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/bowl_of_fruit.jpg")
|
||||
run_count = 1
|
||||
|
||||
kwargs = dict(
|
||||
prompt="a white bowl filled with gold coins. sharp focus",
|
||||
prompt_strength=12,
|
||||
init_image=img,
|
||||
init_image_strength=0.2,
|
||||
mask_prompt="(fruit OR stem{*5} OR fruit stem)",
|
||||
mask_mode="replace",
|
||||
steps=20,
|
||||
seed=946188797,
|
||||
sampler_type="plms",
|
||||
fix_faces=True,
|
||||
upscale=True,
|
||||
)
|
||||
kwargs = {
|
||||
"prompt": "a white bowl filled with gold coins. sharp focus",
|
||||
"prompt_strength": 12,
|
||||
"init_image": img,
|
||||
"init_image_strength": 0.2,
|
||||
"mask_prompt": "(fruit OR stem{*5} OR fruit stem)",
|
||||
"mask_mode": "replace",
|
||||
"steps": 20,
|
||||
"seed": 946188797,
|
||||
"sampler_type": "plms",
|
||||
"fix_faces": True,
|
||||
"upscale": True,
|
||||
}
|
||||
prompts = [
|
||||
ImaginePrompt(**kwargs),
|
||||
ImaginePrompt(**kwargs),
|
||||
|
tox.ini (2 lines changed)
@@ -13,7 +13,7 @@ linters = pylint,pycodestyle,pydocstyle,pyflakes,mypy
ignore =
Z999,C0103,C0301,C0302,C0114,C0115,C0116,
Z999,D100,D101,D102,D103,D105,D106,D107,D200,D202,D203,D205,D212,D400,D401,D406,D407,D413,D415,D417,
Z999,E203,E501,E1101,
Z999,E203,E501,E1101,E1131,
Z999,R0901,R0902,R0903,R0904,R0193,R0912,R0913,R0914,R0915,R1702,
Z999,W0221,W0511,W0612,W0613,W1203
