imaginAIry/imaginairy/modules/clip_embedders.py


"""Classes for text and image encoding"""
import kornia
import torch
from einops import repeat
from torch import nn
from transformers import CLIPTextModel, CLIPTokenizer
from imaginairy.utils import get_device
from imaginairy.vendored import clip


class FrozenCLIPEmbedder(nn.Module):
    """Uses the CLIP transformer encoder for text (from Hugging Face)."""

    def __init__(
        self,
        version="openai/clip-vit-large-patch14",
        device=get_device(),
        max_length=77,
    ):
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length
        self.freeze()

    def freeze(self):
        # Put the text encoder in eval mode and disable gradients so it is
        # never updated during training or sampling.
        self.transformer = self.transformer.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        # Tokenize, padding/truncating every prompt to max_length tokens.
        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=True,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
        tokens = batch_encoding["input_ids"].to(self.device)
        outputs = self.transformer(input_ids=tokens)
        # Per-token hidden states, shape (batch, max_length, hidden_dim).
        z = outputs.last_hidden_state
        return z

    def encode(self, text):
        return self(text)
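

# A minimal usage sketch (illustrative, not part of the original module).
# The prompt is hypothetical; with the default ViT-L/14 weights the hidden
# size is 768, so the encoded shape should be (1, 77, 768).
def _demo_frozen_clip_embedder():
    embedder = FrozenCLIPEmbedder()
    embedder.to(embedder.device)  # HF weights load on CPU; move to the target device
    z = embedder.encode(["a painting of a corgi in a field"])
    print(z.shape)  # expected: torch.Size([1, 77, 768])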


class FrozenCLIPTextEmbedder(nn.Module):
    """Uses the CLIP transformer encoder for text (vendored OpenAI CLIP)."""

    def __init__(
        self,
        version="ViT-L/14",
        device=get_device(),
        max_length=77,
        n_repeat=1,
        normalize=True,
    ):
        super().__init__()
        self.model, _ = clip.load(version, jit=False, device=device)
        self.device = device
        self.max_length = max_length
        self.n_repeat = n_repeat
        self.normalize = normalize

    def freeze(self):
        # Note: unlike FrozenCLIPEmbedder, freeze() is not called in
        # __init__; the caller is expected to invoke it.
        self.model = self.model.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        tokens = clip.tokenize(text).to(self.device)
        # Pooled sentence embedding, shape (batch, embed_dim).
        z = self.model.encode_text(tokens)
        if self.normalize:
            # L2-normalize each embedding onto the unit sphere.
            z = z / torch.linalg.norm(z, dim=1, keepdim=True)
        return z

    def encode(self, text):
        z = self(text)
        if z.ndim == 2:
            # Insert a sequence axis: (b, d) -> (b, 1, d).
            z = z[:, None, :]
        # Tile along the sequence axis: (b, 1, d) -> (b, n_repeat, d).
        z = repeat(z, "b 1 d -> b k d", k=self.n_repeat)
        return z
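

# A minimal usage sketch (illustrative, not part of the original module).
# The prompt is hypothetical; ViT-L/14 produces 768-dimensional embeddings.
def _demo_frozen_clip_text_embedder():
    embedder = FrozenCLIPTextEmbedder()
    embedder.freeze()  # not called automatically by __init__ here
    z = embedder.encode(["a photo of a lighthouse at dusk"])
    print(z.shape)  # expected: torch.Size([1, 1, 768]) with n_repeat=1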


class FrozenClipImageEmbedder(nn.Module):
    """Uses the CLIP image encoder."""

    def __init__(
        self,
        model_name,
        jit=False,
        device=get_device(),
        antialias=False,
    ):
        super().__init__()
        self.model, _ = clip.load(name=model_name, device=device, jit=jit)
        self.antialias = antialias

        # CLIP's channel statistics, registered as non-persistent buffers so
        # they move with the module across devices but are not saved in
        # checkpoints.
        self.register_buffer(
            "mean", torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False
        )
        self.register_buffer(
            "std", torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False
        )

    def preprocess(self, x):
        # Resize to CLIP's expected 224x224 input.
        x = kornia.geometry.resize(
            x,
            (224, 224),
            interpolation="bicubic",
            align_corners=True,
            antialias=self.antialias,
        )
        # Map from [-1, 1] to [0, 1].
        x = (x + 1.0) / 2.0
        # Renormalize according to CLIP's mean/std.
        x = kornia.enhance.normalize(x, self.mean, self.std)
        return x

    def forward(self, x):
        # x is assumed to be in range [-1, 1]
        return self.model.encode_image(self.preprocess(x))
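

# A minimal usage sketch (illustrative, not part of the original module).
# The model name and the random input batch are assumptions; the fake batch
# stands in for images already scaled to [-1, 1].
def _demo_frozen_clip_image_embedder():
    device = get_device()
    embedder = FrozenClipImageEmbedder(model_name="ViT-L/14").to(device)
    images = (torch.rand(1, 3, 512, 512, device=device) * 2.0) - 1.0
    z = embedder(images)
    print(z.shape)  # expected: torch.Size([1, 768]) for ViT-L/14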