From 1b1a2d57404b3e1acab1f69a7c926d4d9472fb72 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Mon, 30 Oct 2023 12:29:54 -0400 Subject: [PATCH] Image Caption accepts bytes for images (#12561) Accept bytes for images in image caption --------- Co-authored-by: webcoderz <19884161+webcoderz@users.noreply.github.com> --- .../document_loaders/image_captions.py | 52 +++++++++++-------- 1 file changed, 29 insertions(+), 23 deletions(-) diff --git a/libs/langchain/langchain/document_loaders/image_captions.py b/libs/langchain/langchain/document_loaders/image_captions.py index 40d3de9379..a5179a685b 100644 --- a/libs/langchain/langchain/document_loaders/image_captions.py +++ b/libs/langchain/langchain/document_loaders/image_captions.py @@ -1,3 +1,4 @@ +from io import BytesIO from typing import Any, List, Tuple, Union import requests @@ -16,30 +17,28 @@ class ImageCaptionLoader(BaseLoader): def __init__( self, - path_images: Union[str, List[str]], + images: Union[str, bytes, List[Union[str, bytes]]], blip_processor: str = "Salesforce/blip-image-captioning-base", blip_model: str = "Salesforce/blip-image-captioning-base", ): - """ - Initialize with a list of image paths + """Initialize with a list of image data (bytes) or file paths Args: - path_images: A list of image paths. + images: Either a single image or a list of images. Accepts + image data (bytes) or file paths to images. blip_processor: The name of the pre-trained BLIP processor. blip_model: The name of the pre-trained BLIP model. """ - if isinstance(path_images, str): - self.image_paths = [path_images] + if isinstance(images, (str, bytes)): + self.images = [images] else: - self.image_paths = path_images + self.images = images self.blip_processor = blip_processor self.blip_model = blip_model def load(self) -> List[Document]: - """ - Load from a list of image files - """ + """Load from a list of image data or file paths""" try: from transformers import BlipForConditionalGeneration, BlipProcessor except ImportError: @@ -52,9 +51,9 @@ class ImageCaptionLoader(BaseLoader): model = BlipForConditionalGeneration.from_pretrained(self.blip_model) results = [] - for path_image in self.image_paths: + for image in self.images: caption, metadata = self._get_captions_and_metadata( - model=model, processor=processor, path_image=path_image + model=model, processor=processor, image=image ) doc = Document(page_content=caption, metadata=metadata) results.append(doc) @@ -62,11 +61,9 @@ class ImageCaptionLoader(BaseLoader): return results def _get_captions_and_metadata( - self, model: Any, processor: Any, path_image: str + self, model: Any, processor: Any, image: Union[str, bytes] ) -> Tuple[str, dict]: - """ - Helper function for getting the captions and metadata of an image - """ + """Helper function for getting the captions and metadata of an image.""" try: from PIL import Image except ImportError: @@ -74,20 +71,29 @@ class ImageCaptionLoader(BaseLoader): "`PIL` package not found, please install with `pip install pillow`" ) + image_source = image # Save the original source for later reference + try: - if path_image.startswith("http://") or path_image.startswith("https://"): - image = Image.open(requests.get(path_image, stream=True).raw).convert( - "RGB" - ) + if isinstance(image, bytes): + image = Image.open(BytesIO(image)).convert("RGB") + elif image.startswith("http://") or image.startswith("https://"): + image = Image.open(requests.get(image, stream=True).raw).convert("RGB") else: - image = Image.open(path_image).convert("RGB") + image = Image.open(image).convert("RGB") except Exception: - raise ValueError(f"Could not get image data for {path_image}") + if isinstance(image_source, bytes): + msg = "Could not get image data from bytes" + else: + msg = f"Could not get image data for {image_source}" + raise ValueError(msg) inputs = processor(image, "an image of", return_tensors="pt") output = model.generate(**inputs) caption: str = processor.decode(output[0]) - metadata: dict = {"image_path": path_image} + if isinstance(image_source, bytes): + metadata: dict = {"image_source": "Image bytes provided"} + else: + metadata = {"image_path": image_source} return caption, metadata