Mirror of https://github.com/hwchase17/langchain
Synced 2024-11-16 06:13:16 +00:00

Add Chroma multimodal cookbook (#12952)

Pending:
* https://github.com/chroma-core/chroma/pull/1294
* https://github.com/chroma-core/chroma/pull/1293

---------
Co-authored-by: Erick Friis <erick@langchain.dev>
Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent 55912868da
commit d2e50b3108

cookbook/multi_modal_RAG_chroma.ipynb (new file, +476 lines)
File diff suppressed because one or more lines are too long.
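The notebook's diff is suppressed, but the pattern it demonstrates is the one wired up by the code changes below: images and text embedded into a shared CLIP space inside one Chroma collection, queried with plain text. A minimal illustrative sketch, not the notebook's code — the collection name and file paths are hypothetical, and the import path assumes the new package layout introduced in this commit:

    from langchain.vectorstores import Chroma
    from langchain_experimental.open_clip import OpenCLIPEmbeddings  # assumed new location

    # One collection holds both modalities; OpenCLIP embeds them into the same space.
    vectorstore = Chroma(
        collection_name="mm_rag_demo",  # hypothetical name
        embedding_function=OpenCLIPEmbeddings(),
    )

    vectorstore.add_images(uris=["figures/chart1.jpg", "figures/chart2.jpg"])  # hypothetical paths
    vectorstore.add_texts(texts=["Revenue grew 20% year over year."])

    # A text query can now retrieve semantically matching images or text.
    docs = vectorstore.similarity_search("revenue growth trend", k=3)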
@@ -0,0 +1,3 @@ (new file: a package __init__.py re-exporting OpenCLIPEmbeddings)
from .open_clip import OpenCLIPEmbeddings

__all__ = ["OpenCLIPEmbeddings"]
@@ -0,0 +1,87 @@ (new file: the OpenCLIPEmbeddings implementation)
from typing import Any, Dict, List

from langchain.pydantic_v1 import BaseModel, root_validator
from langchain.schema.embeddings import Embeddings


class OpenCLIPEmbeddings(BaseModel, Embeddings):
    model: Any
    preprocess: Any
    tokenizer: Any

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that open_clip and torch libraries are installed."""
        try:
            import open_clip

            ### Smaller, less performant
            # model_name = "ViT-B-32"
            # checkpoint = "laion2b_s34b_b79k"
            ### Larger, more performant
            model_name = "ViT-g-14"
            checkpoint = "laion2b_s34b_b88k"
            model, _, preprocess = open_clip.create_model_and_transforms(
                model_name=model_name, pretrained=checkpoint
            )
            tokenizer = open_clip.get_tokenizer(model_name)
            values["model"] = model
            values["preprocess"] = preprocess
            values["tokenizer"] = tokenizer

        except ImportError:
            raise ImportError(
                "Please ensure both open_clip and torch libraries are installed. "
                "pip install open_clip_torch torch"
            )
        return values

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        text_features = []
        for text in texts:
            # Tokenize the text
            tokenized_text = self.tokenizer(text)

            # Encode the text to get the embeddings
            embeddings_tensor = self.model.encode_text(tokenized_text)

            # Normalize the embeddings
            norm = embeddings_tensor.norm(p=2, dim=1, keepdim=True)
            normalized_embeddings_tensor = embeddings_tensor.div(norm)

            # Convert normalized tensor to list and add to the text_features list
            embeddings_list = normalized_embeddings_tensor.squeeze(0).tolist()
            text_features.append(embeddings_list)

        return text_features

    def embed_query(self, text: str) -> List[float]:
        return self.embed_documents([text])[0]

    def embed_image(self, uris: List[str]) -> List[List[float]]:
        try:
            from PIL import Image as _PILImage
        except ImportError:
            raise ImportError("Please install the PIL library: pip install pillow")

        # Open images directly as PIL images
        pil_images = [_PILImage.open(uri) for uri in uris]

        image_features = []
        for pil_image in pil_images:
            # Preprocess the image for the model
            preprocessed_image = self.preprocess(pil_image).unsqueeze(0)

            # Encode the image to get the embeddings
            embeddings_tensor = self.model.encode_image(preprocessed_image)

            # Normalize the embeddings tensor
            norm = embeddings_tensor.norm(p=2, dim=1, keepdim=True)
            normalized_embeddings_tensor = embeddings_tensor.div(norm)

            # Convert tensor to list and add to the image_features list
            embeddings_list = normalized_embeddings_tensor.squeeze(0).tolist()

            image_features.append(embeddings_list)

        return image_features
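For reference, a minimal usage sketch of the class above (illustrative, not part of the diff; assumes open_clip_torch and torch are installed, and cat.jpg is a hypothetical local image). Instantiation runs the root validator, which downloads the sizable ViT-g-14 checkpoint on first use:

    clip_embd = OpenCLIPEmbeddings()  # root validator loads model, preprocess, tokenizer

    # Texts and images share one embedding space; all vectors are L2-normalized.
    text_vecs = clip_embd.embed_documents(["a photo of a cat", "a photo of a dog"])
    query_vec = clip_embd.embed_query("a small kitten")
    img_vecs = clip_embd.embed_image(uris=["cat.jpg"])  # hypothetical file

    # Unit-norm vectors make a plain dot product a cosine similarity.
    sim = sum(q * v for q, v in zip(query_vec, img_vecs[0]))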
@@ -54,7 +54,6 @@ from langchain.embeddings.mosaicml import MosaicMLInstructorEmbeddings
 from langchain.embeddings.nlpcloud import NLPCloudEmbeddings
 from langchain.embeddings.octoai_embeddings import OctoAIEmbeddings
 from langchain.embeddings.ollama import OllamaEmbeddings
-from langchain.embeddings.open_clip import OpenCLIPEmbeddings
 from langchain.embeddings.openai import OpenAIEmbeddings
 from langchain.embeddings.sagemaker_endpoint import SagemakerEndpointEmbeddings
 from langchain.embeddings.self_hosted import SelfHostedEmbeddings

@@ -120,7 +119,6 @@ __all__ = [
     "QianfanEmbeddingsEndpoint",
     "JohnSnowLabsEmbeddings",
     "VoyageEmbeddings",
-    "OpenCLIPEmbeddings",
 ]
@@ -1,56 +0,0 @@ (file deleted: the previous OpenCLIPEmbeddings implementation)
from typing import Any, Dict, List

import numpy as np

from langchain.pydantic_v1 import BaseModel, root_validator
from langchain.schema.embeddings import Embeddings


class OpenCLIPEmbeddings(BaseModel, Embeddings):
    model: Any
    preprocess: Any
    tokenizer: Any

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that open_clip and torch libraries are installed."""
        try:
            import open_clip

            model_name = "ViT-B-32"
            checkpoint = "laion2b_s34b_b79k"
            model, _, preprocess = open_clip.create_model_and_transforms(
                model_name=model_name, pretrained=checkpoint
            )
            tokenizer = open_clip.get_tokenizer(model_name)
            values["model"] = model
            values["preprocess"] = preprocess
            values["tokenizer"] = tokenizer

        except ImportError:
            raise ImportError(
                "Please ensure both open_clip and torch libraries are installed. "
                "pip install open_clip_torch torch"
            )
        return values

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        text_features = [
            self.model.encode_text(self.tokenizer(text)).tolist() for text in texts
        ]
        return text_features

    def embed_query(self, text: str) -> List[float]:
        return self.embed_documents([text])[0]

    def embed_image(self, images: List[np.ndarray]) -> List[List[float]]:
        try:
            from PIL import Image as _PILImage
        except ImportError:
            raise ImportError("Please install the PIL library: pip install pillow")
        pil_images = [_PILImage.fromarray(image) for image in images]
        image_features = [
            self.model.encode_image(self.preprocess(pil_image).unsqueeze(0)).tolist()
            for pil_image in pil_images
        ]
        return image_features
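The deleted version above differs from its replacement in three ways: it pinned the smaller ViT-B-32 checkpoint, it accepted images as numpy arrays rather than file URIs, and it neither normalized the embeddings nor stripped the batch dimension — `encode_text` returns a `(1, d)` tensor, so its `.tolist()` yielded a nested list rather than the promised `List[float]`. A small sketch of that pitfall (illustrative, assumes torch):

    import torch

    batch_embedding = torch.randn(1, 512)  # shape an encoder returns for a single input

    nested = batch_embedding.tolist()           # [[0.12, ...]] -- batch dim still wrapped
    flat = batch_embedding.squeeze(0).tolist()  # [0.12, ...] -- what embed_documents should return

    assert isinstance(nested[0], list)
    assert isinstance(flat[0], float)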
@@ -1,5 +1,6 @@
 from __future__ import annotations

+import base64
 import logging
 import uuid
 from typing import (

@@ -160,6 +161,94 @@ class Chroma(VectorStore):
            **kwargs,
        )

    def encode_image(self, uri: str) -> str:
        """Get base64 string from image URI."""
        with open(uri, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode("utf-8")

    def add_images(
        self,
        uris: List[str],
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Run more images through the embeddings and add to the vectorstore.

        Args:
            uris (List[str]): File paths or URIs of the images to add.
            metadatas (Optional[List[dict]], optional): Optional list of metadatas.
            ids (Optional[List[str]], optional): Optional list of IDs.

        Returns:
            List[str]: List of IDs of the added images.
        """
        # Map from uris to b64 encoded strings; the encoded image becomes the
        # stored document text
        b64_texts = [self.encode_image(uri=uri) for uri in uris]
        # Populate IDs
        if ids is None:
            ids = [str(uuid.uuid1()) for _ in uris]
        embeddings = None
        # Set embeddings if the embedding function supports images
        if self._embedding_function is not None and hasattr(
            self._embedding_function, "embed_image"
        ):
            embeddings = self._embedding_function.embed_image(uris=uris)
        if metadatas:
            # fill metadatas with empty dicts if somebody
            # did not specify metadata for all images
            length_diff = len(uris) - len(metadatas)
            if length_diff:
                metadatas = metadatas + [{}] * length_diff
            empty_ids = []
            non_empty_ids = []
            for idx, m in enumerate(metadatas):
                if m:
                    non_empty_ids.append(idx)
                else:
                    empty_ids.append(idx)
            if non_empty_ids:
                metadatas = [metadatas[idx] for idx in non_empty_ids]
                images_with_metadatas = [b64_texts[idx] for idx in non_empty_ids]
                embeddings_with_metadatas = (
                    [embeddings[idx] for idx in non_empty_ids] if embeddings else None
                )
                ids_with_metadata = [ids[idx] for idx in non_empty_ids]
                try:
                    self._collection.upsert(
                        metadatas=metadatas,
                        embeddings=embeddings_with_metadatas,
                        documents=images_with_metadatas,
                        ids=ids_with_metadata,
                    )
                except ValueError as e:
                    if "Expected metadata value to be" in str(e):
                        msg = (
                            "Try filtering complex metadata using "
                            "langchain.vectorstores.utils.filter_complex_metadata."
                        )
                        raise ValueError(e.args[0] + "\n\n" + msg)
                    else:
                        raise e
            if empty_ids:
                images_without_metadatas = [b64_texts[j] for j in empty_ids]
                embeddings_without_metadatas = (
                    [embeddings[j] for j in empty_ids] if embeddings else None
                )
                ids_without_metadatas = [ids[j] for j in empty_ids]
                self._collection.upsert(
                    embeddings=embeddings_without_metadatas,
                    documents=images_without_metadatas,
                    ids=ids_without_metadatas,
                )
        else:
            self._collection.upsert(
                embeddings=embeddings,
                documents=b64_texts,
                ids=ids,
            )
        return ids

    def add_texts(
        self,
        texts: Iterable[str],
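Because `add_images` stores the base64 payload as the document text, a retrieved image can be decoded straight from the search result. A round-trip sketch (illustrative; assumes a `vectorstore` populated via `add_images` as in the sketch near the top, and a hypothetical output path):

    import base64

    docs = vectorstore.similarity_search("a bar chart of revenue", k=1)

    # For image documents, page_content is the base64 string written by add_images.
    image_bytes = base64.b64decode(docs[0].page_content)
    with open("retrieved.png", "wb") as f:  # hypothetical output path
        f.write(image_bytes)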
@@ -49,7 +49,6 @@ EXPECTED_ALL = [
     "QianfanEmbeddingsEndpoint",
     "JohnSnowLabsEmbeddings",
     "VoyageEmbeddings",
-    "OpenCLIPEmbeddings",
 ]