mirror of
https://github.com/brycedrennan/imaginAIry
synced 2024-11-17 09:25:47 +00:00
a46424c673
Kinda hacky copy/pasting from ddim. Need to cleanup
60 lines
1.6 KiB
Python
60 lines
1.6 KiB
Python
import os
|
|
import os.path
|
|
from functools import lru_cache
|
|
|
|
import torch
|
|
from torchvision import transforms
|
|
from torchvision.transforms.functional import InterpolationMode
|
|
|
|
from imaginairy.utils import get_cached_url_path, get_device
|
|
from imaginairy.vendored.blip.blip import BLIP_Decoder, load_checkpoint
|
|
|
|
# Device used for BLIP inference: whatever get_device() reports, except that any
# "mps" device is replaced with the CPU.
# NOTE(review): presumably because the vendored BLIP model does not run reliably on
# Apple's MPS backend — confirm before removing this override.
device = get_device()
device = "cpu" if "mps" in device else device

# Side length (pixels) of the square model input: passed to BLIP_Decoder as
# image_size and used as the Resize target when preprocessing images.
BLIP_EVAL_SIZE = 384
|
|
|
|
|
|
@lru_cache()
def blip_model():
    """Build the BLIP base captioning model and return it ready for inference.

    The decoder is constructed from the vendored ``med_config.json``. Its weights
    are loaded from the BLIP base-caption checkpoint, fetched via
    ``get_cached_url_path``. The model is then switched to eval mode and moved to
    the module-level ``device``.

    Because this zero-argument function is wrapped in ``lru_cache``, the model is
    built at most once per process and the same instance is shared by all callers.
    """
    # Imported inside the function; presumably to avoid a circular import with the
    # top-level package — confirm before hoisting.
    from imaginairy import PKG_ROOT  # noqa

    med_config_file = os.path.join(
        PKG_ROOT, "vendored", "blip", "configs", "med_config.json"
    )
    # NOTE(review): the "*" is presumably part of the upstream checkpoint file name,
    # not a wildcard — verify against the bucket before changing it.
    checkpoint_url = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth"

    decoder = BLIP_Decoder(
        image_size=BLIP_EVAL_SIZE, vit="base", med_config=med_config_file
    )
    # load_checkpoint returns a (model, msg) pair; the message is not used here.
    decoder, _ = load_checkpoint(decoder, get_cached_url_path(checkpoint_url))  # noqa
    decoder.eval()
    return decoder.to(device)
|
|
|
|
|
|
@lru_cache()
def _blip_preprocess():
    """Return the torchvision pipeline that turns an image into a normalized BLIP input.

    Built once and cached, since it depends only on module-level constants.
    """
    return transforms.Compose(
        [
            transforms.Resize(
                (BLIP_EVAL_SIZE, BLIP_EVAL_SIZE),
                interpolation=InterpolationMode.BICUBIC,
            ),
            transforms.ToTensor(),
            # Per-channel (R, G, B) mean and std. These are presumably the
            # normalization constants the upstream BLIP checkpoint expects (they
            # match CLIP's values). Keep them in sync with the checkpoint.
            transforms.Normalize(
                (0.48145466, 0.4578275, 0.40821073),
                (0.26862954, 0.26130258, 0.27577711),
            ),
        ]
    )


def generate_caption(image):
    """Generate a text caption for ``image`` using the cached BLIP model.

    Args:
        image: The picture to caption. Anything ``transforms.ToTensor`` accepts is
            allowed; presumably a PIL image in practice (TODO confirm with callers).
            A PIL image whose mode is not ``"RGB"`` (e.g. ``"RGBA"``, ``"L"``,
            ``"P"``) is converted to RGB first. This is needed because the
            Normalize step uses 3-channel mean/std and fails on 1- or 4-channel
            tensors.

    Returns:
        The first element of ``blip_model().generate(...)``'s output. This is
        presumably the caption string (TODO confirm against BLIP_Decoder.generate).
    """
    # Only objects exposing a PIL-style ``mode`` attribute are converted. ndarray
    # inputs have no ``mode`` and are passed through unchanged, as before.
    if getattr(image, "mode", "RGB") != "RGB":
        image = image.convert("RGB")

    # Prepend a leading dimension of size 1 (a batch of one image) and move the
    # tensor to the inference device.
    gpu_image = _blip_preprocess()(image).unsqueeze(0).to(device)

    with torch.no_grad():
        # Deterministic decoding (sample=False) with 3 beams. Caption length is
        # bounded by min_length=5 and max_length=20 as passed to BLIP's generate.
        caption = blip_model().generate(
            gpu_image, sample=False, num_beams=3, max_length=20, min_length=5
        )
    return caption[0]
|