Add Chroma multimodal cookbook (#12952)

Pending:
* https://github.com/chroma-core/chroma/pull/1294
* https://github.com/chroma-core/chroma/pull/1293

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
Co-authored-by: Bagatur <baskaryan@gmail.com>
Lance Martin 2023-11-10 09:43:10 -08:00 committed by GitHub
parent 55912868da
commit d2e50b3108
8 changed files with 835 additions and 118 deletions
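The two suppressed diffs below likely carry the cookbook notebook itself. As a minimal sketch of the flow the cookbook demonstrates — assuming hypothetical image paths, and assuming the new OpenCLIPEmbeddings class ships under langchain_experimental.open_clip (the exact module path is not visible in this diff):

# Sketch only: embed images and text into one CLIP space and search across both.
from langchain.vectorstores import Chroma
from langchain_experimental.open_clip import OpenCLIPEmbeddings  # assumed path

vectorstore = Chroma(
    collection_name="multimodal_demo",
    embedding_function=OpenCLIPEmbeddings(),
)
vectorstore.add_images(uris=["photos/cat.jpg", "photos/dog.jpg"])  # hypothetical files
docs = vectorstore.similarity_search("a photo of a cat")  # text query over image vectors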

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,3 @@
from .open_clip import OpenCLIPEmbeddings
__all__ = ["OpenCLIPEmbeddings"]

View File

@@ -0,0 +1,87 @@
from typing import Any, Dict, List

from langchain.pydantic_v1 import BaseModel, root_validator
from langchain.schema.embeddings import Embeddings


class OpenCLIPEmbeddings(BaseModel, Embeddings):
    model: Any
    preprocess: Any
    tokenizer: Any

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that open_clip and torch libraries are installed."""
        try:
            import open_clip

            ### Smaller, less performant
            # model_name = "ViT-B-32"
            # checkpoint = "laion2b_s34b_b79k"
            ### Larger, more performant
            model_name = "ViT-g-14"
            checkpoint = "laion2b_s34b_b88k"
            model, _, preprocess = open_clip.create_model_and_transforms(
                model_name=model_name, pretrained=checkpoint
            )
            tokenizer = open_clip.get_tokenizer(model_name)
            values["model"] = model
            values["preprocess"] = preprocess
            values["tokenizer"] = tokenizer
        except ImportError:
            raise ImportError(
                "Please ensure both open_clip and torch libraries are installed. "
                "pip install open_clip_torch torch"
            )
        return values

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        text_features = []
        for text in texts:
            # Tokenize the text
            tokenized_text = self.tokenizer(text)

            # Encode the text to get the embeddings
            embeddings_tensor = self.model.encode_text(tokenized_text)

            # Normalize the embeddings
            norm = embeddings_tensor.norm(p=2, dim=1, keepdim=True)
            normalized_embeddings_tensor = embeddings_tensor.div(norm)

            # Convert normalized tensor to list and add to the text_features list
            embeddings_list = normalized_embeddings_tensor.squeeze(0).tolist()
            text_features.append(embeddings_list)

        return text_features

    def embed_query(self, text: str) -> List[float]:
        return self.embed_documents([text])[0]

    def embed_image(self, uris: List[str]) -> List[List[float]]:
        try:
            from PIL import Image as _PILImage
        except ImportError:
            raise ImportError("Please install the PIL library: pip install pillow")

        # Open images directly as PIL images
        pil_images = [_PILImage.open(uri) for uri in uris]

        image_features = []
        for pil_image in pil_images:
            # Preprocess the image for the model
            preprocessed_image = self.preprocess(pil_image).unsqueeze(0)

            # Encode the image to get the embeddings
            embeddings_tensor = self.model.encode_image(preprocessed_image)

            # Normalize the embeddings tensor
            norm = embeddings_tensor.norm(p=2, dim=1, keepdim=True)
            normalized_embeddings_tensor = embeddings_tensor.div(norm)

            # Convert tensor to list and add to the image_features list
            embeddings_list = normalized_embeddings_tensor.squeeze(0).tolist()
            image_features.append(embeddings_list)

        return image_features
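
A quick sanity-check sketch for the class above: embed_query and embed_image both return L2-normalized vectors in a shared space, so cosine similarity between a caption and an image reduces to a dot product ("cat.jpg" is a hypothetical file):

embedder = OpenCLIPEmbeddings()
text_vec = embedder.embed_query("a photo of a cat")
image_vec = embedder.embed_image(uris=["cat.jpg"])[0]  # hypothetical image path
cosine = sum(t * i for t, i in zip(text_vec, image_vec))  # both vectors are unit-norm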

View File

@@ -54,7 +54,6 @@ from langchain.embeddings.mosaicml import MosaicMLInstructorEmbeddings
from langchain.embeddings.nlpcloud import NLPCloudEmbeddings
from langchain.embeddings.octoai_embeddings import OctoAIEmbeddings
from langchain.embeddings.ollama import OllamaEmbeddings
from langchain.embeddings.open_clip import OpenCLIPEmbeddings
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.embeddings.sagemaker_endpoint import SagemakerEndpointEmbeddings
from langchain.embeddings.self_hosted import SelfHostedEmbeddings
@@ -120,7 +119,6 @@ __all__ = [
"QianfanEmbeddingsEndpoint",
"JohnSnowLabsEmbeddings",
"VoyageEmbeddings",
"OpenCLIPEmbeddings",
]

View File

@@ -1,56 +0,0 @@
from typing import Any, Dict, List

import numpy as np

from langchain.pydantic_v1 import BaseModel, root_validator
from langchain.schema.embeddings import Embeddings


class OpenCLIPEmbeddings(BaseModel, Embeddings):
    model: Any
    preprocess: Any
    tokenizer: Any

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that open_clip and torch libraries are installed."""
        try:
            import open_clip

            model_name = "ViT-B-32"
            checkpoint = "laion2b_s34b_b79k"
            model, _, preprocess = open_clip.create_model_and_transforms(
                model_name=model_name, pretrained=checkpoint
            )
            tokenizer = open_clip.get_tokenizer(model_name)
            values["model"] = model
            values["preprocess"] = preprocess
            values["tokenizer"] = tokenizer
        except ImportError:
            raise ImportError(
                "Please ensure both open_clip and torch libraries are installed. "
                "pip install open_clip_torch torch"
            )
        return values

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        text_features = [
            self.model.encode_text(self.tokenizer(text)).tolist() for text in texts
        ]
        return text_features

    def embed_query(self, text: str) -> List[float]:
        return self.embed_documents([text])[0]

    def embed_image(self, images: List[np.ndarray]) -> List[List[float]]:
        try:
            from PIL import Image as _PILImage
        except ImportError:
            raise ImportError("Please install the PIL library: pip install pillow")

        pil_images = [_PILImage.fromarray(image) for image in images]
        image_features = [
            self.model.encode_image(self.preprocess(pil_image).unsqueeze(0)).tolist()
            for pil_image in pil_images
        ]
        return image_features

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
import base64
import logging
import uuid
from typing import (
@@ -160,6 +161,94 @@ class Chroma(VectorStore):
            **kwargs,
        )

    def encode_image(self, uri: str) -> str:
        """Get base64 string from image URI."""
        with open(uri, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode("utf-8")

    def add_images(
        self,
        uris: List[str],
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Run more images through the embeddings and add to the vectorstore.

        Args:
            uris (List[str]): File paths of the images to add to the vectorstore.
            metadatas (Optional[List[dict]], optional): Optional list of metadatas.
            ids (Optional[List[str]], optional): Optional list of IDs.

        Returns:
            List[str]: List of IDs of the added images.
        """
        # Map from URIs to base64-encoded strings
        b64_texts = [self.encode_image(uri=uri) for uri in uris]
        # Populate IDs
        if ids is None:
            ids = [str(uuid.uuid1()) for _ in uris]
        embeddings = None
        # Set embeddings
        if self._embedding_function is not None and hasattr(
            self._embedding_function, "embed_image"
        ):
            embeddings = self._embedding_function.embed_image(uris=uris)
        if metadatas:
            # Fill metadatas with empty dicts if somebody
            # did not specify metadata for all images
            length_diff = len(uris) - len(metadatas)
            if length_diff:
                metadatas = metadatas + [{}] * length_diff
            empty_ids = []
            non_empty_ids = []
            for idx, m in enumerate(metadatas):
                if m:
                    non_empty_ids.append(idx)
                else:
                    empty_ids.append(idx)
            if non_empty_ids:
                metadatas = [metadatas[idx] for idx in non_empty_ids]
                # Documents are the base64 strings, matching the no-metadata path
                images_with_metadatas = [b64_texts[idx] for idx in non_empty_ids]
                embeddings_with_metadatas = (
                    [embeddings[idx] for idx in non_empty_ids] if embeddings else None
                )
                ids_with_metadata = [ids[idx] for idx in non_empty_ids]
                try:
                    self._collection.upsert(
                        metadatas=metadatas,
                        embeddings=embeddings_with_metadatas,
                        documents=images_with_metadatas,
                        ids=ids_with_metadata,
                    )
                except ValueError as e:
                    if "Expected metadata value to be" in str(e):
                        msg = (
                            "Try filtering complex metadata using "
                            "langchain.vectorstores.utils.filter_complex_metadata."
                        )
                        raise ValueError(e.args[0] + "\n\n" + msg)
                    else:
                        raise e
            if empty_ids:
                images_without_metadatas = [b64_texts[j] for j in empty_ids]
                embeddings_without_metadatas = (
                    [embeddings[j] for j in empty_ids] if embeddings else None
                )
                ids_without_metadatas = [ids[j] for j in empty_ids]
                self._collection.upsert(
                    embeddings=embeddings_without_metadatas,
                    documents=images_without_metadatas,
                    ids=ids_without_metadatas,
                )
        else:
            self._collection.upsert(
                embeddings=embeddings,
                documents=b64_texts,
                ids=ids,
            )
        return ids

    def add_texts(
        self,
        texts: Iterable[str],

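A hedged usage sketch for add_images above, reusing the vectorstore from the sketch at the top (paths and metadata are hypothetical): entries with metadata and entries without are upserted as separate batches, and the stored documents are the base64-encoded images.

ids = vectorstore.add_images(
    uris=["img/one.png", "img/two.png"],  # hypothetical files
    metadatas=[{"source": "img/one.png"}],  # short list is padded with {} for image two
)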
View File

@@ -49,7 +49,6 @@ EXPECTED_ALL = [
"QianfanEmbeddingsEndpoint",
"JohnSnowLabsEmbeddings",
"VoyageEmbeddings",
"OpenCLIPEmbeddings",
]