2023-03-18 12:26:19 +00:00
|
|
|
import torch
|
|
|
|
from PIL import Image
|
2023-04-03 07:43:34 +00:00
|
|
|
from transformers import BlipForConditionalGeneration, BlipProcessor
|
2023-03-20 08:27:20 +00:00
|
|
|
|
2023-03-23 07:33:45 +00:00
|
|
|
from core.prompts.file import IMAGE_PROMPT
|
2023-03-18 12:26:19 +00:00
|
|
|
|
2023-03-23 07:33:45 +00:00
|
|
|
from .base import BaseHandler
|
2023-03-18 12:26:19 +00:00
|
|
|
|
|
|
|
|
|
|
|
class ImageCaptioning(BaseHandler):
    """Generate a natural-language caption for an image file using BLIP.

    The image is first rescaled so its longer side is exactly 512 pixels
    (the file on disk is overwritten with the resized PNG), then run through
    the Salesforce BLIP image-captioning model to produce a description.
    """

    def __init__(self, device):
        """Load the BLIP processor and model onto *device*.

        Args:
            device: torch device string, e.g. "cpu" or "cuda:0".
        """
        print(f"Initializing ImageCaptioning to {device}")
        self.device = device
        # Half precision saves GPU memory; CPUs generally lack fast fp16
        # kernels, so fall back to float32 there.
        self.torch_dtype = torch.float16 if "cuda" in device else torch.float32
        self.processor = BlipProcessor.from_pretrained(
            "Salesforce/blip-image-captioning-base"
        )
        self.model = BlipForConditionalGeneration.from_pretrained(
            "Salesforce/blip-image-captioning-base", torch_dtype=self.torch_dtype
        ).to(self.device)

    def handle(self, filename: str):
        """Caption the image at *filename* and return a formatted prompt.

        Side effect: the image is resized in place (longer side scaled to
        512 px, aspect ratio preserved) and re-saved as PNG under the same
        filename.

        Args:
            filename: path to the image file to caption.

        Returns:
            ``IMAGE_PROMPT`` formatted with the filename and the generated
            description.
        """
        img = Image.open(filename)
        width, height = img.size
        # Scale so the longer side becomes 512 px while keeping the aspect
        # ratio (this also upscales images smaller than 512 px).
        ratio = min(512 / width, 512 / height)
        width_new, height_new = round(width * ratio), round(height * ratio)
        img = img.resize((width_new, height_new)).convert("RGB")
        img.save(filename, "PNG")
        print(f"Resize image from {width}x{height} to {width_new}x{height_new}")

        # Reuse the in-memory, already-resized image rather than re-reading
        # the file we just wrote.
        inputs = self.processor(img, return_tensors="pt").to(
            self.device, self.torch_dtype
        )
        out = self.model.generate(**inputs)
        description = self.processor.decode(out[0], skip_special_tokens=True)
        print(
            f"\nProcessed ImageCaptioning, Input Image: {filename}, "
            f"Output Text: {description}"
        )

        return IMAGE_PROMPT.format(filename=filename, description=description)
|