imaginAIry/imaginairy/enhancers/describe_image_blip.py

60 lines
1.6 KiB
Python
Raw Normal View History

import os
import os.path
from functools import lru_cache
import torch
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from imaginairy.utils import get_cached_url_path, get_device
from imaginairy.vendored.blip.blip import BLIP_Decoder, load_checkpoint
device = get_device()
if "mps" in device:
device = "cpu"
BLIP_EVAL_SIZE = 384
@lru_cache()
def blip_model():
from imaginairy import PKG_ROOT # noqa
config_path = os.path.join(
PKG_ROOT, "vendored", "blip", "configs", "med_config.json"
)
url = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth"
model = BLIP_Decoder(image_size=BLIP_EVAL_SIZE, vit="base", med_config=config_path)
cached_url_path = get_cached_url_path(url)
model, msg = load_checkpoint(model, cached_url_path) # noqa
model.eval()
model = model.to(device)
return model
def generate_caption(image):
gpu_image = (
transforms.Compose(
[
transforms.Resize(
(BLIP_EVAL_SIZE, BLIP_EVAL_SIZE),
interpolation=InterpolationMode.BICUBIC,
),
transforms.ToTensor(),
transforms.Normalize(
(0.48145466, 0.4578275, 0.40821073),
(0.26862954, 0.26130258, 0.27577711),
),
]
)(image)
.unsqueeze(0)
.to(device)
)
with torch.no_grad():
caption = blip_model().generate(
gpu_image, sample=False, num_beams=3, max_length=20, min_length=5
)
return caption[0]