langchain/libs/experimental/langchain_experimental/open_clip/open_clip.py
Bagatur 8e0d5813c2
langchain[patch], experimental[patch]: replace langchain.schema imports (#15410)
Import from core instead.

Ran:
```bash
git grep -l 'from langchain.schema\.output_parser' | xargs -L 1 sed -i '' "s/from\ langchain\.schema\.output_parser/from\ langchain_core.output_parsers/g"
git grep -l 'from langchain.schema\.messages' | xargs -L 1 sed -i '' "s/from\ langchain\.schema\.messages/from\ langchain_core.messages/g"
git grep -l 'from langchain.schema\.document' | xargs -L 1 sed -i '' "s/from\ langchain\.schema\.document/from\ langchain_core.documents/g"
git grep -l 'from langchain.schema\.runnable' | xargs -L 1 sed -i '' "s/from\ langchain\.schema\.runnable/from\ langchain_core.runnables/g"
git grep -l 'from langchain.schema\.vectorstore' | xargs -L 1 sed -i '' "s/from\ langchain\.schema\.vectorstore/from\ langchain_core.vectorstores/g"
git grep -l 'from langchain.schema\.language_model' | xargs -L 1 sed -i '' "s/from\ langchain\.schema\.language_model/from\ langchain_core.language_models/g"
git grep -l 'from langchain.schema\.embeddings' | xargs -L 1 sed -i '' "s/from\ langchain\.schema\.embeddings/from\ langchain_core.embeddings/g"
git grep -l 'from langchain.schema\.storage' | xargs -L 1 sed -i '' "s/from\ langchain\.schema\.storage/from\ langchain_core.stores/g"
git checkout master libs/langchain/tests/unit_tests/schema/
make format
cd libs/experimental
make format
cd ../langchain
make format
```
2024-01-02 15:09:45 -05:00

90 lines
3.3 KiB
Python

from typing import Any, Dict, List
from langchain.pydantic_v1 import BaseModel, root_validator
from langchain_core.embeddings import Embeddings
class OpenCLIPEmbeddings(BaseModel, Embeddings):
    """Embed text and images with an OpenCLIP model.

    The model, its image preprocessing transform and its tokenizer are loaded
    during validation from ``model_name`` and ``checkpoint`` and stored on the
    ``model``, ``preprocess`` and ``tokenizer`` fields.
    """

    # Populated by ``validate_environment``; any value passed in is overwritten.
    model: Any
    preprocess: Any
    tokenizer: Any
    # Select model: https://github.com/mlfoundations/open_clip
    model_name: str = "ViT-H-14"
    checkpoint: str = "laion2b_s32b_b79k"

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that open_clip and torch libraries are installed.

        Loads the model, preprocessing transform and tokenizer and stores them
        in ``values`` under ``"model"``, ``"preprocess"`` and ``"tokenizer"``.

        Args:
            values: Field values collected so far for the instance.

        Returns:
            The same ``values`` dict with the three loaded objects added.

        Raises:
            ImportError: If ``open_clip`` cannot be imported.
        """
        # Guard only the import: an ImportError raised while the model is
        # being loaded should surface with its own message instead of being
        # replaced by the installation hint.
        try:
            import open_clip
        except ImportError as err:
            raise ImportError(
                "Please ensure both open_clip and torch libraries are installed. "
                "pip install open_clip_torch torch"
            ) from err

        # Fall back to class defaults if not provided
        model_name = values.get("model_name", cls.__fields__["model_name"].default)
        checkpoint = values.get("checkpoint", cls.__fields__["checkpoint"].default)

        # Load model
        model, _, preprocess = open_clip.create_model_and_transforms(
            model_name=model_name, pretrained=checkpoint
        )
        # The open_clip README notes that models are returned in training
        # mode; this class only runs inference, so switch to evaluation mode
        # (matters for architectures with BatchNorm or stochastic depth).
        model.eval()

        values["model"] = model
        values["preprocess"] = preprocess
        values["tokenizer"] = open_clip.get_tokenizer(model_name)
        return values

    @staticmethod
    def _normalized_as_list(embeddings_tensor: Any) -> List[float]:
        """L2-normalize a ``(1, dim)`` embedding tensor and return it as a list."""
        norm = embeddings_tensor.norm(p=2, dim=1, keepdim=True)
        return embeddings_tensor.div(norm).squeeze(0).tolist()

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed each text with the model's text encoder.

        Args:
            texts: The strings to embed.

        Returns:
            One L2-normalized embedding per input text, in input order.
        """
        text_features = []
        for text in texts:
            # Tokenize, encode, then normalize one text at a time.
            tokenized_text = self.tokenizer(text)
            embeddings_tensor = self.model.encode_text(tokenized_text)
            text_features.append(self._normalized_as_list(embeddings_tensor))
        return text_features

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string; equivalent to ``embed_documents([text])[0]``."""
        return self.embed_documents([text])[0]

    def embed_image(self, uris: List[str]) -> List[List[float]]:
        """Embed each image with the model's image encoder.

        Args:
            uris: Image locations, each passed directly to ``PIL.Image.open``.

        Returns:
            One L2-normalized embedding per input image, in input order.

        Raises:
            ImportError: If Pillow is not installed.
        """
        try:
            from PIL import Image as _PILImage
        except ImportError as err:
            raise ImportError(
                "Please install the PIL library: pip install pillow"
            ) from err

        image_features = []
        for uri in uris:
            # Open each image in a context manager so its file handle is
            # released as soon as the pixel data has been turned into a tensor.
            with _PILImage.open(uri) as pil_image:
                # Preprocess the image for the model and add a batch dimension
                preprocessed_image = self.preprocess(pil_image).unsqueeze(0)
            # Encode the image to get the embeddings
            embeddings_tensor = self.model.encode_image(preprocessed_image)
            image_features.append(self._normalized_as_list(embeddings_tensor))
        return image_features