From ea1ab391d421ff054def570109f3ca511536ef3d Mon Sep 17 00:00:00 2001 From: Lance Martin <122662504+rlancemartin@users.noreply.github.com> Date: Fri, 3 Nov 2023 13:33:36 -0700 Subject: [PATCH] Open Clip multimodal embeddings (#12754) --- .../text_embedding/open_clip.ipynb | 145 ++++++++++++++++++ .../langchain/embeddings/__init__.py | 2 + .../langchain/embeddings/open_clip.py | 56 +++++++ .../unit_tests/embeddings/test_imports.py | 1 + 4 files changed, 204 insertions(+) create mode 100644 docs/docs/integrations/text_embedding/open_clip.ipynb create mode 100644 libs/langchain/langchain/embeddings/open_clip.py diff --git a/docs/docs/integrations/text_embedding/open_clip.ipynb b/docs/docs/integrations/text_embedding/open_clip.ipynb new file mode 100644 index 0000000000..e00315ddbc --- /dev/null +++ b/docs/docs/integrations/text_embedding/open_clip.ipynb @@ -0,0 +1,145 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "bb9b2af6-325f-4d1e-8e74-96ca5c2e27c5", + "metadata": {}, + "source": [ + "# OpenClip\n", + "\n", + "[OpenClip](https://github.com/mlfoundations/open_clip/tree/main) is an source implementation of OpenAI's CLIP.\n", + "\n", + "These multi-modal embeddings can be used to embed images or text.\n", + "\n", + "For text, use the same method `embed_documents` as with other embedding models.\n", + "\n", + "For images, use `embed_image` and simply pass a numpy array." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "570d818f-5705-4532-8e77-b0f335bb515d", + "metadata": {}, + "outputs": [], + "source": [ + "! pip install pillow open_clip_torch torch" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "c4fc9af7-5659-4008-b6a3-0f99c68324aa", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import seaborn as sns\n", + "from PIL import Image as _PILImage\n", + "from langchain.embeddings import OpenCLIPEmbeddings\n", + "\n", + "# Images\n", + "img_path_dog='/Users/rlm/Desktop/Papers/LLaVA/dog.jpeg'\n", + "img_path_house='/Users/rlm/Desktop/Papers/LLaVA/house.jpeg'\n", + "\n", + "# Load images and convert to numpy arrays\n", + "image_np_dog = np.array(_PILImage.open(img_path_dog).convert(\"RGB\"))\n", + "image_np_house = np.array(_PILImage.open(img_path_house).convert(\"RGB\"))\n", + "\n", + "# Embe images or text\n", + "clip_embd = OpenCLIPEmbeddings()\n", + "img_feat_dog = clip_embd.embed_image([image_np_dog])\n", + "img_feat_house = clip_embd.embed_image([image_np_house])\n", + "text_feat_dog = clip_embd.embed_documents([\"dog\"])\n", + "text_feat_house = clip_embd.embed_documents([\"house\"])" + ] + }, + { + "cell_type": "markdown", + "id": "52b5e8d0-d1c7-475f-9d33-b8a1c492398a", + "metadata": {}, + "source": [ + "## Sanity check\n", + "\n", + "We can check simiarlity. " + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "6d7962f3-8f05-463d-97b2-955eaa89d18d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(2, 2)\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAf8AAAGdCAYAAAAczXrvAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy81sbWrAAAACXBIWXMAAA9hAAAPYQGoP6dpAAA39UlEQVR4nO3deXxTVd7H8W8KNC2lLRQtZaeAYJFSEEGh7CAKgiyOMMouuMxTQKiiVERA1MywKyI4OAM4iugom4AFFFqQVXaqCFiqIIsglZYWKKW5zx8McZKCNJ20Ccnn/Xrdl+Tk5txfnieZX3/nnJxrMgzDEAAA8Bl+7g4AAAAUL5I/AAA+huQPAICPIfkDAOBjSP4AAPgYkj8AAD6G5A8AgI8h+QMA4GNI/gAA+JiS7g7gmrVpUe4OAfA4o6Y95e4QAI+0Z+bIIu3feqqOy/ryizjksr5cxWOSPwAAnsIqq8v68sQhdk+MCQAAFCEqfwAAHOQZrqv8PTHRemJMAAC4lVXefcNbkj8AAA5cOefviZjzBwDAx1D5AwDgIM9g2B8AAJ/i7XP+DPsDAOBjqPwBAHCQR+UPAIBvscpw2eEMi8WiJk2aKDg4WOHh4erevbsOHjxod86pU6fUr18/RUREKCgoSHfffbc+++wzp65D8gcAwEMkJycrLi5OW7du1dq1a5Wbm6uOHTsqOzvbdk7//v118OBBLV++XPv371fPnj3Vq1cv7d69u8DXYdgfAAAH7lrtn5iYaPd4/vz5Cg8P186dO9WqVStJ0ubNmzV79mw1bdpUkvTyyy9r+vTp2rlzpxo1alSg61D5AwDgwOrCIycnR5mZmXZHTk5OgeLIyMiQJIWFhdnamjdvro8//ljp6emyWq1atGiRLl26pDZt2hT4/ZH8AQAoQhaLRaGhoXaHxWK56eusVqtGjBih2NhY1a9f39b+ySefKDc3V+XLl5fZbNbTTz+tJUuWqHbt2gWOiWF/AAAcuHK1f0JCguLj4+3azGbzTV8XFxenlJQUff3113btY8eO1blz5/Tll1/qtttu09KlS9WrVy9t3LhR0dHRBYqJ5A8AgIM8F075m83mAiX7/zZ06FCtWLFCGzZsUJUqVWztqampevvtt5WSkqK77rpLkhQTE6ONGzdq1qxZmjNnToH6J/kDAODAXbf1MQxDw4YN05IlS5SUlKTIyEi75y9cuCBJ8vOzn7UvUaKErNaCR03yBwDAQ8TFxWnhwoVatmyZgoODderUKUlSaGioAgMDdeedd6p27dp6+umnNWXKFJUvX15Lly7V2rVrtWLFigJfh+QPAICDPJncct3Zs2dLUr6V+/PmzdPAgQNVqlQprVq1SqNHj1bXrl2VlZWl2rVra8GCBercuXOBr0PyBwDAgdVNu/saBdhf4I477nB6Rz9H/NQPAAAfQ+UPAIADdw37FxeSPwAADrw9+TPsDwCAj6HyBwDAgdXw7sqf5A8AgAOG/QEAgFeh8gcAwEGel9fGJH8AABww5w8AgI9hzh8AAHgVKn8AABzkGd5dG5P8AQBwYPXygXHvfncAACAfKn8AABx4+4I/kj8AAA68fc7fu98dAADIh8ofAAAHVob9AQDwLd6+va93vzsAAJAPlT8AAA68fcEfyR8AAAfevskPyR8AAAd5Xn5XP+/+0wYAAORD5Q8AgANvX+1P8gcAwIHVyxf8efe7AwAA+VD5AwDggGF/AAB8DKv9AQCAV6HyBwDAAZv8AADgY7x9e1/vfncAACAfKn8AABxY5d0L/kj+AAA48PZhf5I/AAAOvP13/t797gAAQD5U/gAAOLB6+SY/JH8AABww7A8AALwKlT8AAA68/Za+JH8AABzkefnv/J3+06Zdu3Y6d+5cvvbMzEy1a9fOFTEBAIAi5HTln5SUpMuXL+drv3TpkjZu3OiSoAAAcCeG/f9j3759tn9/9913OnXqlO1xXl6eEhMTVblyZddGBwCAG3j7sH+Bk3/Dhg1lMplkMpmuO7wfGBiomTNnujQ4AADgegUe10hLS1NqaqoMw9D27duVlpZmO44fP67MzEw98cQTRRkrAADFwmr4uexwhsViUZMmTRQcHKzw8HB1795dBw8ezHfeli1b1K5dOwUFBSkkJEStWrXSxYsXC3ydAlf+1atXlyRZrdYbnmMYhkwm7x4qAQB4P3fd2Cc5OVlxcXFq0qSJrly5opdeekkdO3bUd999p6CgIElXE/+DDz6ohIQEzZw5UyVLltTevXvl51fwmJ1e8Ddw4EDNmjXLFsQ1P/74o/r168eiPwDALc9dt/RNTEy0ezx//nyFh4dr586datWqlSRp5MiRGj58uEaPHm07r27duk5dx+k/bfbu3asGDRpoy5YttrYFCxYoJiZGt912m7PdAQDg1XJycpSZmWl35OTkFOi1GRkZkqSwsDBJ0unTp7Vt2zaFh4erefPmqlChglq3bq2vv/7aqZicTv7bt29Xz5491aZNG7300kvq1auXhg4dqilTpmjJkiXOdgcAgMfJM/xcdlgsFoWGhtodFovlpjFYrVaNGDFCsbGxql+/viTpyJEjkqTx48frySefVGJiou6++261b99ehw8fLvD7c3rYv1SpUpo8ebJKly6tiRMnqmTJkkpOTlazZs2c7QoAAI/kyrv6vZyQoPj4eLs2s9l809fFxcUpJSXFrqq/tu7u6aef1qBBgyRJjRo10ldffaV//vOfBfqjQipE5Z+bm6vnnntOf/vb35SQkKBmzZqpZ8+eWrVqlbNdAQDg9cxms0JCQuyOmyX/oUOHasWKFVq/fr2qVKlia69YsaIkqV69enbnR0VF6ejRowWOyenK/5577tGFCxeUlJSk++67T4ZhaNKkSerZs6eeeOIJvfPOO852CQCAR3HXLX0Nw9CwYcO0ZMkSJSUlKTIy0u75GjVqqFKlSvl+/nfo0CF16tSpwNcpVPJ/6623bKv9TSaTXnzxRXXs2FH9+vVztjsAADyOK4f9nREXF6eFCxdq2bJlCg4Otu2mGxoaqsDAQJlMJo0aNUrjxo1TTEyMGjZsqAULFuj777/Xp59+WuDrOJ38//GPf1y3vVGjRtq5c6ez3QEAgP+YPXu2JKlNmzZ27fPmzdPAgQMlSSNGjNClS5c0cuRIpaenKyYmRmvXrlWtWrUKfJ1C3dL3X//6l+bMmaO0tDRt2bJF1atX14wZMxQZGalu3boVpksAADyG1Y3D/gUxevRou9/5O8vpdzd79mzFx8erc+fOOnfunPLy8iRJZcuW1YwZMwodCAAAniLPMLns8EROJ/+ZM2dq7ty5GjNmjEqUKGFrv+eee7R//36XBgcAAFzP6WH/tLQ0NWrUKF+72WxWdna2S4ICAMCd3LXgr7g4XflHRkZqz549+doTExMVFRXlipgAAHArd93Vr7g4XfnHx8crLi5Oly5dst3e96OPPpLFYtF7771XFDECAFCs8tx0Y5/i4nTyHzJkiAIDA/Xyyy/rwoULevzxx1WpUiW9+eab+vOf/1wUMQIAABcq1E/9+vTpoz59+ujChQvKyspSeHi4q+MCAMBtvH3O3+nk365dOy1evFhly5ZV6dKlVbp0aUlSZmamunfvrnXr1rk8SLjW6kWG9m4y9MvPUil/qWY9qdsTJlWoevXDfvaUoXEDr/9b0ydeMunuVt79pYDveuL+JmofU1s1KoQpJ/eK9qad0IxlX+un07/ZzvEvWULP9WilBxrXlX/JEtp84Ce98ck6pZ+/4MbI4WqeOlfvKk4n/6SkJF2+fDlf+6VLl7Rx40aXBIWi9cN+Q626mlS9jpRnlT6fZ+jtMYZe/rtkDjCp3O3SGwvtE/ymL6QvPzV0VxM3BQ0Ug8a1q+jjjXv17U+/qEQJk4Z1jdXsuJ7q+foCXbp8RZL0fM/WanlXpEb9c6WyLuZo9KNtNW1IVw2c/rGbowcKrsDJf9++fbZ/f/fdd7b9hiUpLy9PiYmJqly5smujQ5GIe93+L9q+z0kJfzZ07LBUO1ryK2FSSJj9a/ZuturulpI5kKof3itu9hK7x698sEbrLc+oXtUK2pV6XGUC/NWjWX0lLPhC3xw6Jkka9+EaLX15oKJrRGj/j6eu1y1uQVYW/F3VsGFDmUwmmUwmtWvXLt/zgYGBmjlzpkuDQ/G49J/RytLB13/+6GFDP6dKveK8+8sAOCoT4C9JyrhwSZIUVa2CSpUsoW0Hf7916o+//KYT6ZmKiaxI8vcinrozn6sUOPmnpaXJMAzVrFlT27dv1+233257zt/fX+Hh4XY7/uHWYLUa+nSOoZr1pEo1rv9h37LaUEQ1qWY97/4yAP/NZJJGPdJGu1OPK/XkWUnSbcGldTn3is5fzLE7N/38BZUPDnJHmEChFDj5V69eXZJktVr/54vm5OQoJ8f+y3M5xyp/s3cvsPBEn8wydPJHaeTU6yf2yzmGdqyXHnycxA/fkvBoO9WuWF4DZ3zi7lDgBt6+4M8t785isSg0NNTuWDT7rDtC8WmfzLIqZZs0fJJJ5W6/fnLfs1G6nCM1bV/MwQFuNPrRtmpVv6aGzPxUp89l2dp/PX9B/qVKKjjQbHd+WHBpnT3P9ubexGqYXHZ4Irck/4SEBGVkZNgdf/5LeXeE4pMMw9Ans6zau1ka/jeTbou48Ydz82pD0fdJwWU98wMMuNroR9uqXYPaemrmpzpxNtPuuQNHf1HulTw1rVPV1lY9vJwqhYVob9rJ4g4VKLRCbfLzvzKbzTKb7f9y9j/r3UMsnuSTWVeH8p8aZ1JAoJSZfvU3/QFBkr/59yR/5oSh1BTpLxNJ/PANL/Vqp06N62rE3OXKvnRZ5YOv7mOSdSlHObl5yrp0WUu2pOi5nq2VceGSsi9d1ug/tdXeIydY7OdlWO0Pr7NxxdX/vvmC/UY+feNNuq/j74+3rDZU9jbpzruLMTjAjXq1jJEk/ePZXnbtr3ywWsu3fSdJmrI4WYZhaOrgrlc3+fn+R73xMZubeRtPHa53FZNhGNffyq2YrU3jjoCAo1HTnnJ3CIBH2jNzZJH233vLMy7r6+Nmc1zWl6s4XfmXK1dOJlP+v4hMJpMCAgJUu3ZtDRw4UIMGDXJJgAAAwLWcTv6vvPKKXn/9dXXq1ElNmzaVJG3fvl2JiYmKi4tTWlqa/vKXv+jKlSt68sknXR4wAABFzduH/Z1O/l9//bVee+01PfOM/ZDIu+++qzVr1uizzz5TgwYN9NZbb5H8AQC3JG9f8Of0EvvVq1erQ4cO+drbt2+v1atXS5I6d+6sI0eO/O/RAQAAl3M6+YeFhenzzz/P1/75558rLOzq3WCys7MVHHyDjeIBAPBw3r7Jj9PD/mPHjtVf/vIXrV+/3jbn/80332jVqlWaM+fqisa1a9eqdevWro0UAIBi4qlJ21WcTv5PPvmk6tWrp7fffluLFy+WJNWtW1fJyclq3ry5JOm5555zbZQAAMBlCrXJT2xsrGJjY10dCwAAHoHK/zry8vK0dOlSHThwQJJ011136eGHH+aWvgAAr0Dyd/DDDz+oc+fOOn78uOrWrSvp6l36qlatqpUrV6pWrVouDxIAALiO06v9hw8frlq1aunYsWPatWuXdu3apaNHjyoyMlLDhw8vihgBAChWVplcdngipyv/5ORkbd261fazPkkqX768/vrXv7IOAADgFRj2d2A2m3X+/Pl87VlZWfL393dJUAAAuJO3J3+nh/27dOmip556Stu2bZNhGDIMQ1u3btUzzzyjhx9+uChiBAAALuR08n/rrbdUq1YtNWvWTAEBAQoICFBsbKxq166tN998syhiBACgWLHDn4OyZctq2bJlOnz4sL7//ntJUlRUlGrXru3y4AAAcAdPTdquUqjf+UvSHXfcoTvuuMOVsQAAgGJQoOQfHx9f4A6nTZtW6GAAAPAEBpW/tHv3brvHu3bt0pUrV2yb/Bw6dEglSpRQ48aNXR8hAADFzFN/n+8qBUr+69evt/172rRpCg4O1oIFC1SuXDlJ0m+//aZBgwapZcuWRRMlAABwGafn/KdOnao1a9bYEr8klStXTq+99po6duzIHf0AALc8Fvw5yMzM1JkzZ/K1nzlz5rqb/wAAcKvx9jl/p3/n36NHDw0aNEiLFy/Wzz//rJ9//lmfffaZBg8erJ49exZFjAAAwIWcrvznzJmj559/Xo8//rhyc3OvdlKypAYPHqzJkye7PEAAAIobw/4OSpcurXfeeUeTJ09WamqqJKlWrVoKCgpyeXAAALiDtw/7F3qTn6CgIDVo0MCVsQAA4BG8vfJ3es4fAADc2gpd+QMA4K0Mw90RFC0qfwAAHFhlctnhDIvFoiZNmig4OFjh4eHq3r27Dh48eN1zDcNQp06dZDKZtHTpUqeuQ/IHAMBDJCcnKy4uTlu3btXatWuVm5urjh07Kjs7O9+5M2bMkMlUuLUJDPsDAODAXav9ExMT7R7Pnz9f4eHh2rlzp1q1amVr37Nnj6ZOnaodO3aoYsWKTl+H5A8AgANPWe2fkZEhSQoLC7O1XbhwQY8//rhmzZqliIiIQvVL8gcAoAjl5OQoJyfHrs1sNstsNv/h66xWq0aMGKHY2FjVr1/f1j5y5Eg1b95c3bp1K3RMzPkDAODAMFx3WCwWhYaG2h0Wi+WmMcTFxSklJUWLFi2ytS1fvlzr1q3TjBkz/qf3R+UPAIADV875JyQkKD4+3q7tZlX/0KFDtWLFCm3YsEFVqlSxta9bt06pqakqW7as3fmPPPKIWrZsqaSkpALFRPIHAKAIFWSI/xrDMDRs2DAtWbJESUlJioyMtHt+9OjRGjJkiF1bdHS0pk+frq5duxY4JpI/AAAO3LXaPy4uTgsXLtSyZcsUHBysU6dOSZJCQ0MVGBioiIiI6y7yq1atWr4/FP4IyR8AAAfuWu0/e/ZsSVKbNm3s2ufNm6eBAwe67DokfwAAHLhre1+jEBcuzGtY7Q8AgI+h8gcAwIG75vyLC8kfAAAH3p78GfYHAMDHUPkDAODATev9ig3JHwAABwz7AwAAr0LlDwCAIy8f9yf5AwDgwNuH/Un+AAA4cNcOf8WFOX8AAHwMlT8AAA4Y9gcAwNd4efJn2B8AAB9D5Q8AgANvX/BH8gcAwJGXJ3+G/QEA8DFU/gAAOGC1PwAAvoZhfwAA4E2o/AEAcMCwPwAAvsbLh/1J/gAA5OPdlT9z/gAA+BgqfwAAHDHsDwCAj/Hy5M+wPwAAPobKHwAAR/zUDwAA3+Ltd/Vj2B8AAB9D5Q8AgCMvr/xJ/gAAOPLyOX+G/QEA8DFU/gAAODAx7A8AgI8h+QMA4GOY8wcAAN6Eyh8AAEcM+wMA4GO8PPkz7A8AgI+h8gcAwJGXV/4kfwAAHLHaHwAAeBMqfwAAHLDDHwAAvsbLkz/D/gAA+BiSPwAAHsJisahJkyYKDg5WeHi4unfvroMHD9qeT09P17Bhw1S3bl0FBgaqWrVqGj58uDIyMpy6DskfAAAHJsN1hzOSk5MVFxenrVu3au3atcrNzVXHjh2VnZ0tSTpx4oROnDihKVOmKCUlRfPnz1diYqIGDx7s5PszDI+Y2bjf71F3hwB4nNUn9ro7BMAj+UUcKtL+a745zWV9HXk2vtCvPXPmjMLDw5WcnKxWrVpd95x///vf6tu3r7Kzs1WyZMGW8lH5AwDgoa4N54eFhf3hOSEhIQVO/BKr/QEAyM+FY+I5OTnKycmxazObzTKbzX/4OqvVqhEjRig2Nlb169e/7jm//vqrJk6cqKeeesqpmKj8AQBwZLjusFgsCg0NtTssFstNQ4iLi1NKSooWLVp03eczMzP10EMPqV69eho/frxTb4/KHwCAIpSQkKD4ePt5/5tV/UOHDtWKFSu0YcMGValSJd/z58+f14MPPqjg4GAtWbJEpUqVciomkj8AAA5cucNfQYb4rzEMQ8OGDdOSJUuUlJSkyMjIfOdkZmbqgQcekNls1vLlyxUQEOB0TCR/AAAcuel3cHFxcVq4cKGWLVum4OBgnTp1SpIUGhqqwMBAZWZmqmPHjrpw4YI++OADZWZmKjMzU5J0++23q0SJEgW6DskfAAAPMXv2bElSmzZt7NrnzZungQMHateuXdq2bZskqXbt2nbnpKWlqUaNGgW6DskfAABHbqr8b7b1Tps2bW56TkGQ/AEAcODtd/Xjp34AAPgYKn8AABwZJndHUKRI/gAAOPLyYX+SPwAADpjzBwAAXoXKHwAAR15e+ZP8AQBwwLA/AADwKlT+AAA48vLKn+QPAIAjL0/+DPsDAOBjqPwBAHDAgj8AAOBVSP4AAPgYhv0BAHDk5cP+JH8AABx4+5w/yR8AAEdenvyZ8wcAwMdQ+QMA4MjLK3+SPwAADrx9zp9hfwAAfAyVPwAAjry88if5AwDggGF/AADgVaj8AQBw5OWVP8kfAABHXp78GfYHAMDHUPkDAODA2xf8kfwBAHBE8gcAwMd4efJnzh8AAB9D5Q8AgAPm/AEA8DVenvwZ9gcAwMdQ+QMA4IBhfwAAfI2XJ3+G/QEA8DFU/gAAOPLyyp/kDwCAA5O7AyhiDPsDAOBjqPwBAHDEsD8AAL6Fn/oBAOBrvDz5M+cPAICPofIHAMARlT8AAL7FZLjucIbFYlGTJk0UHBys8PBwde/eXQcPHrQ759KlS4qLi1P58uVVpkwZPfLII/rll1+cug7JHwAAD5GcnKy4uDht3bpVa9euVW5urjp27Kjs7GzbOSNHjtTnn3+uf//730pOTtaJEyfUs2dPp67j9LD/0aNHVbVqVZlM9lsgGIahY8eOqVq1as52CQCAZ3HTsH9iYqLd4/nz5ys8PFw7d+5Uq1atlJGRoX/84x9auHCh2rVrJ0maN2+eoqKitHXrVt13330Fuo7TlX9kZKTOnDmTrz09PV2RkZHOdgcAgMdx17C/o4yMDElSWFiYJGnnzp3Kzc1Vhw4dbOfceeedqlatmrZs2VLgfp2u/A3DyFf1S1JWVpYCAgKc7Q4AAK+Wk5OjnJwcuzaz2Syz2fyHr7NarRoxYoRiY2NVv359SdKpU6fk7++vsmXL2p1boUIFnTp1qsAxFTj5x8fHS5JMJpPGjh2r0qVL257Ly8vTtm3b1LBhwwJfGAAAj+XCYX+LxaIJEybYtY0bN07jx4//w9fFxcUpJSVFX3/9teuC+Y8CJ//du3dLulr579+/X/7+/rbn/P39FRMTo+eff97lAQIAUNxcucNfQkKCrYC+5mZV/9ChQ7VixQpt2LBBVapUsbVHRETo8uXLOnfunF31/8svvygiIqLAMRU4+a9fv16SNGjQIL355psKCQkp8EUAAPBVBRniv8YwDA0bNkxLlixRUlJSvrV0jRs3VqlSpfTVV1/pkUcekSQdPHhQR48eVbNmzQock9Nz/pMmTbph4t+/f7+io6Od7RIAAM/iptX+cXFxWrhwoZYtW6bg4GDbPH5oaKgCAwMVGhqqwYMHKz4+XmFhYQoJCdGwYcPUrFmzAq/0lwqx2j86OlorV67M1z5lyhQ1bdrU2e4AAPA8hgsPJ8yePVsZGRlq06aNKlasaDs+/vhj2znTp09Xly5d9Mgjj6hVq1aKiIjQ4sWLnbqO05V/fHy8HnnkEQ0aNEjTpk1Tenq6+vfvr/3792vhwoXOdgcAgMdx1139DOPmFw4ICNCsWbM0a9asQl/H6cr/hRde0JYtW7Rx40Y1aNBADRo0kNls1r59+9SjR49CBwIAAIpHobb3rV27turXr68ff/xRmZmZ6t27t1OrDAEA8GhuGvYvLk4n/02bNqlBgwY6fPiw9u3bp9mzZ2vYsGHq3bu3fvvtt6KIEQCAYmUyDJcdnsjp5N+uXTv17t1bW7duVVRUlIYMGaLdu3fr6NGjrPQHAOAW4PSCvzVr1qh169Z2bbVq1dKmTZv0+uuvuywwAADcxjMLdpdxuvK/lvh/+OEHrV69WhcvXpT0+7a/AADc6jzlxj5Fxenkf/bsWbVv31516tRR586ddfLkSUnS4MGD2d4XAIBbgNPJf+TIkSpVqpSOHj1qd3Of3r1764svvnBpcAAAuIWXr/Yv1Jz/6tWr7W40IEl33HGHfvrpJ5cFBgCAu3jqcL2rOF35Z2dn21X816Snpxf4xgUAAMB9nE7+LVu21Pvvv297bDKZZLVaNWnSJLVt29alwQEA4BYM+9ubNGmS2rdvrx07dujy5ct64YUX9O233yo9PV2bNm0qihgBAChWDPs7qF+/vg4dOqQWLVqoW7duys7OVs+ePbV7927VqlWrKGIEAKB4UfnbO3r0qKpWraoxY8Zc97lq1aq5JDAAAFA0nK78IyMjdebMmXztZ8+eVWRkpEuCAgDAnbx9kx+nK3/DMGQymfK1Z2VlKSAgwCVBAQDgVh56Qx5XKXDyj4+Pl/T7Nr7//XO/vLw8bdu2TQ0bNnR5gAAAwLUKnPx3794t6Wrlv3//fvn7+9ue8/f3V0xMDNv7AgC8gqcO17tKgZP/+vXrJUmDBg3Sm2++qZCQkCILCgAAtyL525s3b15RxAEAAIqJ08kfAABvZ7K6O4KiRfL3UdEto/To8w+rTuOaKl8pTON6TNLmZd/YnVPtzsoa8te+atC6nvxK+unodz9rwp+m6syxX90UNVC0/v6BtHaDdOSoFGCWGtWXnntaivyv7UuOHpcmvSPt2i9dzpVaNpXGPCvdFua+uFEEvHzY3+nf+cM7BASZdWTfT5o59B/Xfb5izQqavnGijh48rufajtPTMc/rw9c+U+6ly8UcKVB8vtkrPd5DWjRb+sdUKfeKNPh56cLFq89fuCgNeV4ymaT506WFb1895/8SJKuXV4rwLlT+PuqbxD36JnHPDZ8f9Npj2r5qt9578QNb28kjvxRDZID7zJ1s/9iSIMV2M+nbQ4aaxEi7U6Tjp6TF70llgn4/594u0tZdUvN7ij9mFA1vX+1fqMr/X//6l2JjY1WpUiX99NNPkqQZM2Zo2bJlLg0O7mEymXTvQ3fr58MnZPlijD459Z7e2vKGmndr4u7QgGJ1Puvqf0ODr/738uWrVb9/qd/PMftLfn5XpwHgRQzDdYcHcjr5z549W/Hx8ercubPOnTunvLw8SVLZsmU1Y8YMV8cHNygbHqrSwYHq/WJ3fbN6jxIeeE2blm7XuM+eV4NW9dwdHlAsrFbJ8rZ0d7ShOjWvtsXcJQUGSFPelS5eujoNMOkdKS/PpDNn3RsvXMvbt/d1OvnPnDlTc+fO1ZgxY1SiRAlb+z333KP9+wv2p29OTo4yMzPtDquR52woKCJ+fle3b96ybIcWz1ip1L0/6uO/LdW2FbvU5en73RwdUDxenS4dTpOmvvJ7W1hZacYEKWmz1PhBqelDUmaWVK+Ooevseg54LKfn/NPS0tSoUaN87WazWdnZ2QXqw2KxaMKECXZtkYpSLd3lbDgoAhm/nteV3Cv66cAxu/aj3/+s+rF3uikqoPhMnCElb5H+NVOKCLd/LraJtOYj6bdzUokSUkiw1LKHVLWSOyJFkfHQit1VCnVXvz179uRrT0xMVFRUVIH6SEhIUEZGht0RKZKKp7iSe0UHv0lV1TqV7dor31FJv/zEz/zgvQzjauL/cqM0b4ZUpeKNzy1X9mri37pLOvub1C62mIJEsfD2YX+nK//4+HjFxcXp0qVLMgxD27dv10cffSSLxaL33nuvQH2YzWaZzWa7Nj9TiRucjaIQEBSgyrUjbI8jIsNVK6aGMtOzdObYr/r3lOUas2ik9m38TnvXf6smDzZUs66N9Vzb8e4LGihir06XVn4lvf26FBQo2zx+cJmrv/uXpMWrpJrVr04B7PlWemOmNOBR+70AAE9nMgznlyJ++OGHGj9+vFJTUyVJlSpV0oQJEzR48OBCB3K/36OFfi2c16B1PU1dPyFf+5r5SZr8xCxJ0gOD2uqx0T10W5Xy+vngCS0Y/7G2LN9R3KH6tNUn9ro7BJ8S1fr6E/dvjDbUo9PVf099V1qaKGVkSpUipD8/LA3oJeb8i5lfxKEi7b9l98k3P6mANi4d5bK+XKVQyf+aCxcuKCsrS+Hh4Tc/+SZI/kB+JH/g+oo6+bfq5rrkv2GZ5yV/p+f8L168qAsXLkiSSpcurYsXL2rGjBlas2aNy4MDAACu53Ty79atm95//31J0rlz59S0aVNNnTpV3bp10+zZs10eIAAAxc5w4eGBnE7+u3btUsuWLSVJn376qSIiIvTTTz/p/fff11tvveXyAAEAKG7evtrf6eR/4cIFBQdf3etyzZo16tmzp/z8/HTffffZtvoFAACey+nkX7t2bS1dulTHjh3T6tWr1bFjR0nS6dOnFRIS4vIAAQAodlbDdYcHcjr5v/LKK3r++edVo0YN3XvvvWrWrJmkq6MA19v5DwCAW46Xz/k7vcnPn/70J7Vo0UInT55UTEyMrb19+/bq0aOHS4MDAMAdPHWu3lWcTv6SFBERoYiICLu2pk2buiQgAABQtJxO/m3btpXpD7ayWrdu3f8UEAAAblf4/e9uCU4n/4YNG9o9zs3N1Z49e5SSkqIBAwa4Ki4AANyGYX8H06dPv277+PHjlZWV9T8HBAAAipbTq/1vpG/fvvrnP//pqu4AAHAfVvsXzJYtWxQQEOCq7gAAcBsTc/72evbsaffYMAydPHlSO3bs0NixY10WGAAAKBpOD/uHhobaHWFhYWrTpo1WrVqlcePGFUWMAAAUL6sLDyds2LBBXbt2VaVKlWQymbR06VK757OysjR06FBVqVJFgYGBqlevnubMmeP023O68p83b57TFwEA4FbirmH/7OxsxcTE6Iknnsg30i5J8fHxWrdunT744APVqFFDa9as0f/93/+pUqVKevjhhwt8nULP+e/cuVMHDhyQJN11111s7QsAwP+oU6dO6tSp0w2f37x5swYMGKA2bdpIkp566im9++672r59u1PJ3+lh/9OnT6tdu3Zq0qSJhg8fruHDh6tx48Zq3769zpw542x3AAB4Hg9d7d+8eXMtX75cx48fl2EYWr9+vQ4dOmS7yV5BOZ38hw0bpvPnz+vbb79Venq60tPTlZKSoszMTA0fPtzZ7gAA8DyG4bIjJydHmZmZdkdOTk6hwpo5c6bq1aunKlWqyN/fXw8++KBmzZqlVq1aOdWP08k/MTFR77zzjqKiomxt9erV06xZs/TFF1842x0AAB7HZLjusFgs+RbLWyyWQsU1c+ZMbd26VcuXL9fOnTs1depUxcXF6csvv3SqH6fn/K1Wq0qVKpWvvVSpUrJanVzWCACAl0tISFB8fLxdm9lsdrqfixcv6qWXXtKSJUv00EMPSZIaNGigPXv2aMqUKerQoUOB+3K68m/Xrp2effZZnThxwtZ2/PhxjRw5Uu3bt3e2OwAAPI8Lh/3NZrNCQkLsjsIk/9zcXOXm5srPzz51lyhRwuni2+nK/+2339bDDz+sGjVqqGrVqpKkY8eOqX79+vrggw+c7Q4AAI9jctNAdlZWln744Qfb47S0NO3Zs0dhYWGqVq2aWrdurVGjRikwMFDVq1dXcnKy3n//fU2bNs2p6zid/KtWrapdu3bpyy+/1Pfffy9JioqKcmq4AQAA5Ldjxw61bdvW9vjadMGAAQM0f/58LVq0SAkJCerTp4/S09NVvXp1vf7663rmmWecuo7JMDxjA+P7/R51dwiAx1l9Yq+7QwA8kl/EoSLt//7Y11zW19pNL7usL1cp1CY/X331lb766iudPn063zwDd/YDANzyPKIsLjpOJ/8JEybo1Vdf1T333KOKFSvKZDIVRVwAAKCIOJ3858yZo/nz56tfv35FEQ8AAG7HLX0dXL58Wc2bNy+KWAAA8Axenvyd/p3/kCFDtHDhwqKIBQAAFIMCVf7/vTOR1WrV3//+d3355Zdq0KBBvt3+nP2tIQAAHsfLN6wtUPLfvXu33eOGDRtKklJSUuzaWfwHAPAGzPlLWr9+fVHHAQCA5/Dy5O/0nD8AALi1FWqTHwAAvJqXV/4kfwAAHHn5gj+G/QEA8DFU/gAAOGC1PwAAvsbLkz/D/gAA+BgqfwAAHHl55U/yBwDAkZcnf4b9AQDwMVT+AAA48vLf+ZP8AQBwwE/9AADwNV6e/JnzBwDAx1D5AwDgyOrdlT/JHwAARwz7AwAAb0LlDwCAIy+v/En+AAA48vLkz7A/AAA+hsofAABHrPYHAMDHGN69vy/D/gAA+BgqfwAAHHn5gj+SPwAAjpjzBwDAx3h55c+cPwAAPobKHwAAR15e+ZP8AQBw5OXJn2F/AAB8DJU/AACOrN69yQ/JHwAARwz7AwAAb0LlDwCAIy+v/En+AAA48vId/hj2BwDAx1D5AwDgwPDyW/qS/AEAcMSwPwAAPsYwXHc4YcOGDeratasqVaokk8mkpUuX5jvnwIEDevjhhxUaGqqgoCA1adJER48edeo6JH8AADxEdna2YmJiNGvWrOs+n5qaqhYtWujOO+9UUlKS9u3bp7FjxyogIMCp6zDsDwCAIzft8NepUyd16tTphs+PGTNGnTt31qRJk2xttWrVcvo6VP4AADhy07D/H7FarVq5cqXq1KmjBx54QOHh4br33nuvOzVwMyR/AACKUE5OjjIzM+2OnJwcp/s5ffq0srKy9Ne//lUPPvig1qxZox49eqhnz55KTk52qi+SPwAADgyr1WWHxWJRaGio3WGxWJyOyfqfqYhu3bpp5MiRatiwoUaPHq0uXbpozpw5TvXFnD8AAI5cOFyfkJCg+Ph4uzaz2ex0P7fddptKliypevXq2bVHRUXp66+/dqovkj8AAEXIbDYXKtk78vf3V5MmTXTw4EG79kOHDql69epO9UXyBwDAkZs2+cnKytIPP/xge5yWlqY9e/YoLCxM1apV06hRo9S7d2+1atVKbdu2VWJioj7//HMlJSU5dR2SPwAAjty0ve+OHTvUtm1b2+Nr0wUDBgzQ/Pnz1aNHD82ZM0cWi0XDhw9X3bp19dlnn6lFixZOXcdkGJ5x38L7/R51dwiAx1l9Yq+7QwA8kl/EoSLt/4GAPi7ra/WlD13Wl6tQ+QMA4MDw8r39Sf4AADjirn4AAPgWb6/82eQHAAAfQ+UPAIAjLx/295jV/vAMOTk5slgsSkhIcMmmFIA34HsBb0Pyh53MzEyFhoYqIyNDISEh7g4H8Ah8L+BtmPMHAMDHkPwBAPAxJH8AAHwMyR92zGazxo0bx6Im4L/wvYC3YcEfAAA+hsofAAAfQ/IHAMDHkPwBAPAxJH8P1aZNG40YMaJIr/Hjjz/KZDJpz549RXod4EaK43NeFG7VuIFr2Nvfh1WtWlUnT57Ubbfd5u5QgFvK4sWLVapUKXeHARQayd+HlShRQhEREe4OA7jlhIWFuTsE4H/CsL8HyM7OVv/+/VWmTBlVrFhRU6dOzXfOb7/9pv79+6tcuXIqXbq0OnXqpMOHD9udM3fuXFWtWlWlS5dWjx49NG3aNJUtW/aG13Uc9k9KSpLJZNLq1avVqFEjBQYGql27djp9+rS++OILRUVFKSQkRI8//rguXLhg6ycxMVEtWrRQ2bJlVb58eXXp0kWpqal219q8ebMaNmyogIAA3XPPPVq6dGm+KYeUlBR16tRJZcqUUYUKFdSvXz/9+uuvzv8fFLcUq9WqF154QWFhYYqIiND48ePtnj969Ki6deumMmXKKCQkRL169dIvv/xie37gwIHq3r273WtGjBihNm3a2B5/+umnio6OVmBgoMqXL68OHTooOzvb9vx7772nqKgoBQQE6M4779Q777zzhzE7DvvXqFFDr732mu17XL16dS1fvlxnzpyxxd6gQQPt2LHD9pqzZ8/qscceU+XKlVW6dGlFR0fro48+srvO+fPn1adPHwUFBalixYqaPn16vmvn5OTo+eefV+XKlRUUFKR7771XSUlJfxg/QPL3AKNGjVJycrKWLVumNWvWKCkpSbt27bI7Z+DAgdqxY4eWL1+uLVu2yDAMde7cWbm5uZKkTZs26ZlnntGzzz6rPXv26P7779frr79eqHjGjx+vt99+W5s3b9axY8fUq1cvzZgxQwsXLtTKlSu1Zs0azZw503Z+dna24uPjtWPHDn311Vfy8/NTjx49ZLVevSVmZmamunbtqujoaO3atUsTJ07Uiy++aHfNc+fOqV27dmrUqJF27NihxMRE/fLLL+rVq1eh3gNuHQsWLFBQUJC2bdumSZMm6dVXX9XatWslXf3DoFu3bkpPT1dycrLWrl2rI0eOqHfv3gXu/+TJk3rsscf0xBNP6MCBA0pKSlLPnj11bYuTDz/8UK+88opef/11HThwQG+88YbGjh2rBQsWOPU+pk+frtjYWO3evVsPPfSQ+vXrp/79+6tv377atWuXatWqpf79+9uue+nSJTVu3FgrV65USkqKnnrqKfXr10/bt2+39RkfH69NmzZp+fLlWrt2rTZu3JjvfxuGDh2qLVu2aNGiRdq3b58effRRPfjgg/mKA8COAbc6f/684e/vb3zyySe2trNnzxqBgYHGs88+axiGYRw6dMiQZGzatMl2zq+//moEBgbaXte7d2/joYcesuu7T58+Rmho6A2vnZaWZkgydu/ebRiGYaxfv96QZHz55Ze2cywWiyHJSE1NtbU9/fTTxgMPPHDDfs+cOWNIMvbv328YhmHMnj3bKF++vHHx4kXbOXPnzrW79sSJE42OHTva9XPs2DFDknHw4MEbXgu3ttatWxstWrSwa2vSpInx4osvGoZhGGvWrDFKlChhHD161Pb8t99+a0gytm/fbhiGYQwYMMDo1q2bXR/PPvus0bp1a8MwDGPnzp2GJOPHH3+8bgy1atUyFi5caNc2ceJEo1mzZn8Y97Xvp2EYRvXq1Y2+ffvaHp88edKQZIwdO9bWtmXLFkOScfLkyRv2+9BDDxnPPfecYRiGkZmZaZQqVcr497//bXv+3LlzRunSpW3X/umnn4wSJUoYx48ft+unffv2RkJCwg2vA1D5u1lqaqouX76se++919YWFhamunXr2h4fOHBAJUuWtDunfPnyqlu3rg4cOCBJOnjwoJo2bWrXt+PjgmrQoIHt3xUqVFDp0qVVs2ZNu7bTp0/bHh8+fFiPPfaYatasqZCQENWoUUPS1eHaa7E1aNBAAQEBN4xt7969Wr9+vcqUKWM77rzzTknKN4UA7/LfnzdJqlixou3zdeDAAVWtWlVVq1a1PV+vXj2VLVvW9tm/mZiYGLVv317R0dF69NFHNXfuXP3222+Sro5apaamavDgwXafvddee83pz53j90aSoqOj87Vde295eXmaOHGioqOjFRYWpjJlymj16tW2782RI0eUm5tr910JDQ21+9+G/fv3Ky8vT3Xq1LGLPzk5me8N/hAL/pDPf69iNplM+VY1m0wm25C+JHXt2lXVq1fX3LlzValSJVmtVtWvX1+XL18u8DWzsrLUtWtX/e1vf8v3XMWKFQvxLnCruNnn62b8/PxsQ+nXXJsOk64ubF27dq02b95sm7IaM2aMtm3bptKlS0u6ul7mv/+4vva6wr4Pk8l0w7Zr723y5Ml68803NWPGDEVHRysoKEgjRoxw+ntTokQJ7dy5M1+8ZcqUcSp++BYqfzerVauWSpUqpW3bttnafvvtNx06dMj2OCoqSleuXLE75+zZszp48KDq1asnSapbt66++eYbu74dHxeFa3G8/PLLat++vaKiomxV1TV169bV/v37lZOTc8PY7r77bn377beqUaOGateubXcEBQUV+fuAZ4qKitKxY8d07NgxW9t3332nc+fO2T77t99+u06ePGn3Ose9K0wmk2JjYzVhwgTt3r1b/v7+WrJkiSpUqKBKlSrpyJEj+T53kZGRRfreNm3apG7duqlv376KiYlRzZo17b73NWvWVKlSpey+KxkZGXbnNGrUSHl5eTp9+nS++PklD/4Iyd/NypQpo8GDB2vUqFFat26dUlJSNHDgQPn5/f7/mjvuuEPdunXTk08+qa+//lp79+5V3759VblyZXXr1k2SNGzYMK1atUrTpk3T4cOH9e677+qLL76wVRtFpVy5cipfvrz+/ve/64cfftC6desUHx9vd87jjz8uq9Wqp556SgcOHNDq1as1ZcoUSb9XQ3FxcUpPT9djjz2mb775RqmpqVq9erUGDRqkvLy8In0P8FwdOnRQdHS0+vTpo127dmn79u3q37+/WrdurXvuuUeS1K5dO+3YsUPvv/++Dh8+rHHjxiklJcXWx7Zt2/TGG29ox44dOnr0qBYvXqwzZ84oKipKkjRhwgRZLBa99dZbOnTokPbv36958+Zp2rRpRfre7rjjDtuIxIEDB/T000/b/YohODhYAwYM0KhRo7R+/Xp9++23Gjx4sPz8/Gzfmzp16qhPnz7q37+/Fi9erLS0NG3fvl0Wi0UrV64s0vhxayP5e4DJkyerZcuW6tq1qzp06KAWLVqocePGdufMmzdPjRs3VpcuXdSsWTMZhqFVq1bZhhVjY2M1Z84cTZs2TTExMUpMTNTIkSPt5tmLgp+fnxYtWqSdO3eqfv36GjlypCZPnmx3TkhIiD7//HPt2bNHDRs21JgxY/TKK69Iki2+SpUqadOmTcrLy1PHjh0VHR2tESNGqGzZsnZ/CMG3mEwmLVu2TOXKlVOrVq3UoUMH1axZUx9//LHtnAceeEBjx47VCy+8oCZNmuj8+fPq37+/7fmQkBBt2LBBnTt3Vp06dfTyyy9r6tSp6tSpkyRpyJAheu+99zRv3jxFR0erdevWmj9/fpFX/i+//LLuvvtuPfDAA2rTpo0iIiLy/WRx2rRpatasmbp06aIOHTooNjbW9pPEa+bNm6f+/fvrueeeU926ddW9e3d98803qlatWpHGj1sbt/T1Yk8++aS+//57bdy40d2h5PPhhx9q0KBBysjIUGBgoLvDAW4J2dnZqly5sqZOnarBgwe7Oxzcwljw50WmTJmi+++/X0FBQfriiy+0YMGCm25WUlzef/991axZU5UrV9bevXv14osvqlevXiR+4A/s3r1b33//vZo2baqMjAy9+uqrkmSb7gMKi+TvRbZv365Jkybp/Pnzqlmzpt566y0NGTLE3WFJkk6dOqVXXnlFp06dUsWKFfXoo48WehMiwJdMmTJFBw8elL+/vxo3bqyNGzdyPw78zxj2BwDAx7CSCgAAH0PyBwDAx5D8AQDwMSR/AAB8DMkfAAAfQ/IHAMDHkPwBAPAxJH8AAHwMyR8AAB/z//2O+I26RpYSAAAAAElFTkSuQmCC", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Convert to numpy arrays\n", + "img_feat_dog_np = np.array(img_feat_dog[0])\n", + "img_feat_house_np = np.array(img_feat_house[0])\n", + "text_feat_dog_np = np.array(text_feat_dog[0])\n", + "text_feat_house_np = np.array(text_feat_house[0])\n", + "\n", + "# Compute similarity\n", + "similarities = np.array([\n", + " [text_feat_dog_np @ img_feat_dog_np.T][0][0], \n", + " [text_feat_dog_np @ img_feat_house_np.T][0][0],\n", + " [text_feat_house_np @ img_feat_dog_np.T][0][0], \n", + " [text_feat_house_np @ img_feat_house_np.T][0][0]\n", + "]).reshape(2, 2)\n", + "\n", + "# Ensure similarities is of shape (2, 2)\n", + "print(similarities.shape) # Expected: (2, 2)\n", + "\n", + "# Plot heatmap\n", + "sns.heatmap(similarities, annot=True, cmap='viridis', xticklabels=['dog image', 'house image'], yticklabels=['dog text', 'house text'])\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.16" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/libs/langchain/langchain/embeddings/__init__.py b/libs/langchain/langchain/embeddings/__init__.py index fb9c7649d3..9098acd527 100644 --- a/libs/langchain/langchain/embeddings/__init__.py +++ b/libs/langchain/langchain/embeddings/__init__.py @@ -53,6 +53,7 @@ from langchain.embeddings.mosaicml import MosaicMLInstructorEmbeddings from langchain.embeddings.nlpcloud import NLPCloudEmbeddings from langchain.embeddings.octoai_embeddings import OctoAIEmbeddings from langchain.embeddings.ollama import OllamaEmbeddings +from langchain.embeddings.open_clip import OpenCLIPEmbeddings from langchain.embeddings.openai import OpenAIEmbeddings from langchain.embeddings.sagemaker_endpoint import SagemakerEndpointEmbeddings from langchain.embeddings.self_hosted import SelfHostedEmbeddings @@ -117,6 +118,7 @@ __all__ = [ "QianfanEmbeddingsEndpoint", "JohnSnowLabsEmbeddings", "VoyageEmbeddings", + "OpenCLIPEmbeddings", ] diff --git a/libs/langchain/langchain/embeddings/open_clip.py b/libs/langchain/langchain/embeddings/open_clip.py new file mode 100644 index 0000000000..c19cabb253 --- /dev/null +++ b/libs/langchain/langchain/embeddings/open_clip.py @@ -0,0 +1,56 @@ +from typing import Any, Dict, List + +import numpy as np + +from langchain.pydantic_v1 import BaseModel, root_validator +from langchain.schema.embeddings import Embeddings + + +class OpenCLIPEmbeddings(BaseModel, Embeddings): + model: Any + preprocess: Any + tokenizer: Any + + @root_validator() + def validate_environment(cls, values: Dict) -> Dict: + """Validate that open_clip and torch libraries are installed.""" + try: + import open_clip + + model_name = "ViT-B-32" + checkpoint = "laion2b_s34b_b79k" + model, _, preprocess = open_clip.create_model_and_transforms( + model_name=model_name, pretrained=checkpoint + ) + tokenizer = open_clip.get_tokenizer(model_name) + values["model"] = model + values["preprocess"] = preprocess + values["tokenizer"] = tokenizer + + except ImportError: + raise ImportError( + "Please ensure both open_clip and torch libraries are installed. " + "pip install open_clip_torch torch" + ) + return values + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + text_features = [ + self.model.encode_text(self.tokenizer(text)).tolist() for text in texts + ] + return text_features + + def embed_query(self, text: str) -> List[float]: + return self.embed_documents([text])[0] + + def embed_image(self, images: List[np.ndarray]) -> List[List[float]]: + try: + from PIL import Image as _PILImage + except ImportError: + raise ImportError("Please install the PIL library: pip install pillow") + pil_images = [_PILImage.fromarray(image) for image in images] + image_features = [ + self.model.encode_image(self.preprocess(pil_image).unsqueeze(0)).tolist() + for pil_image in pil_images + ] + return image_features diff --git a/libs/langchain/tests/unit_tests/embeddings/test_imports.py b/libs/langchain/tests/unit_tests/embeddings/test_imports.py index 8d05d3634e..dd4f1aebe8 100644 --- a/libs/langchain/tests/unit_tests/embeddings/test_imports.py +++ b/libs/langchain/tests/unit_tests/embeddings/test_imports.py @@ -48,6 +48,7 @@ EXPECTED_ALL = [ "QianfanEmbeddingsEndpoint", "JohnSnowLabsEmbeddings", "VoyageEmbeddings", + "OpenCLIPEmbeddings", ]