# imaginAIry/imaginairy/vendored/refiners/foundationals/latent_diffusion/stable_diffusion_xl/text_encoder.py

from typing import cast

from jaxtyping import Float
from torch import Tensor, cat, device as Device, dtype as DType

import imaginairy.vendored.refiners.fluxion.layers as fl
from imaginairy.vendored.refiners.fluxion.adapters.adapter import Adapter
from imaginairy.vendored.refiners.fluxion.context import Contexts
from imaginairy.vendored.refiners.foundationals.clip.text_encoder import CLIPTextEncoderG, CLIPTextEncoderL
from imaginairy.vendored.refiners.foundationals.clip.tokenizer import CLIPTokenizer


class TextEncoderWithPooling(fl.Chain, Adapter[CLIPTextEncoderG]):
    """Adapter around CLIPTextEncoderG that also returns a pooled text embedding."""

    def __init__(
        self,
        target: CLIPTextEncoderG,
        projection: fl.Linear | None = None,
    ) -> None:
        with self.setup_adapter(target=target):
            tokenizer = target.ensure_find(CLIPTokenizer)
            super().__init__(
                tokenizer,
                # Record the position of the end-of-text token before running the encoder body.
                fl.SetContext(
                    context="text_encoder_pooling", key="end_of_text_index", callback=self.set_end_of_text_index
                ),
                target[1:-2],
                fl.Parallel(
                    # Branch 1: pass the intermediate hidden states through unchanged.
                    fl.Identity(),
                    # Branch 2: run the remaining layers, project, then pool at the end-of-text position.
                    fl.Chain(
                        target[-2:],
                        projection
                        or fl.Linear(
                            in_features=1280, out_features=1280, bias=False, device=target.device, dtype=target.dtype
                        ),
                        fl.Lambda(func=self.pool),
                    ),
                ),
            )

    def init_context(self) -> Contexts:
        return {"text_encoder_pooling": {"end_of_text_index": []}}

    def __call__(self, text: str) -> tuple[Float[Tensor, "1 77 1280"], Float[Tensor, "1 1280"]]:
        return super().__call__(text)

    @property
    def tokenizer(self) -> CLIPTokenizer:
        return self.ensure_find(CLIPTokenizer)

    def set_end_of_text_index(self, end_of_text_index: list[int], tokens: Tensor) -> None:
        # Locate the end-of-text token in the token ids and stash its position in the context.
        position = (tokens == self.tokenizer.end_of_text_token_id).nonzero(as_tuple=True)[1].item()
        end_of_text_index.append(cast(int, position))

    def pool(self, x: Float[Tensor, "1 77 1280"]) -> Float[Tensor, "1 1280"]:
        # Pool by selecting the hidden state at the recorded end-of-text token position.
        end_of_text_index = self.use_context(context_name="text_encoder_pooling").get("end_of_text_index", [])
        assert len(end_of_text_index) == 1, "End of text index not found."
        return x[:, end_of_text_index[0], :]
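

# Illustrative sketch (added for clarity, not part of the vendored module): the
# pooling performed by TextEncoderWithPooling amounts to locating the
# end-of-text token in the token ids and selecting that row of the projected
# hidden states. The concrete token ids and sequence length below are
# assumptions used only for illustration.
#
#     import torch
#     tokens = torch.tensor([[49406, 320, 2368, 49407, 0, 0]])  # [BOS, two word tokens, EOS, padding]
#     hidden = torch.randn(1, 6, 1280)                          # projected hidden states
#     eos_position = (tokens == 49407).nonzero(as_tuple=True)[1].item()
#     pooled = hidden[:, eos_position, :]                       # shape: (1, 1280)

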
class DoubleTextEncoder(fl.Chain):
    """Runs CLIP-L and CLIP-G in parallel, as used by Stable Diffusion XL."""

    def __init__(
        self,
        text_encoder_l: CLIPTextEncoderL | None = None,
        text_encoder_g: CLIPTextEncoderG | None = None,
        projection: fl.Linear | None = None,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        text_encoder_l = text_encoder_l or CLIPTextEncoderL(device=device, dtype=dtype)
        text_encoder_g = text_encoder_g or CLIPTextEncoderG(device=device, dtype=dtype)
        super().__init__(
            fl.Parallel(text_encoder_l[:-2], text_encoder_g),
            fl.Lambda(func=self.concatenate_embeddings),
        )
        # Wrap the CLIP-G encoder so it also produces the pooled embedding.
        TextEncoderWithPooling(target=text_encoder_g, projection=projection).inject(parent=self.Parallel)

    def __call__(self, text: str) -> tuple[Float[Tensor, "1 77 2048"], Float[Tensor, "1 1280"]]:
        return super().__call__(text)

    def concatenate_embeddings(
        self, text_embedding_l: Tensor, text_embedding_with_pooling: tuple[Tensor, Tensor]
    ) -> tuple[Tensor, Tensor]:
        # Concatenate the CLIP-L and CLIP-G token embeddings along the feature dimension
        # and pass the pooled CLIP-G embedding through unchanged.
        text_embedding_g, pooled_text_embedding = text_embedding_with_pooling
        text_embedding = cat(tensors=[text_embedding_l, text_embedding_g], dim=-1)
        return text_embedding, pooled_text_embedding
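

# Usage sketch (added for illustration, not part of the vendored module). It
# instantiates randomly initialised encoders; in practice pretrained weights
# would be loaded into the module before use, which this sketch does not do.
if __name__ == "__main__":
    double_text_encoder = DoubleTextEncoder(device="cpu")
    text_embedding, pooled_text_embedding = double_text_encoder("a photo of a cat")
    print(text_embedding.shape)         # expected: torch.Size([1, 77, 2048])
    print(pooled_text_embedding.shape)  # expected: torch.Size([1, 1280])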