From 5cdd9ab7e1567d3a55e76ec2b80423083d72b5e4 Mon Sep 17 00:00:00 2001 From: Archon Date: Thu, 25 May 2023 21:57:49 +0800 Subject: [PATCH] Add MiniMax embeddings (#5174) - Add support for MiniMax embeddings Doc: [MiniMax embeddings](https://api.minimax.chat/document/guides/embeddings?id=6464722084cdc277dfaa966a) --------- Co-authored-by: Archon Co-authored-by: Dev 2049 --- .../text_embedding/examples/minimax.ipynb | 145 ++++++++++++++++ langchain/embeddings/__init__.py | 2 + langchain/embeddings/minimax.py | 163 ++++++++++++++++++ 3 files changed, 310 insertions(+) create mode 100644 docs/modules/models/text_embedding/examples/minimax.ipynb create mode 100644 langchain/embeddings/minimax.py diff --git a/docs/modules/models/text_embedding/examples/minimax.ipynb b/docs/modules/models/text_embedding/examples/minimax.ipynb new file mode 100644 index 0000000000..bcfbe6912d --- /dev/null +++ b/docs/modules/models/text_embedding/examples/minimax.ipynb @@ -0,0 +1,145 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# MiniMax\n", + "\n", + "[MiniMax](https://api.minimax.chat/document/guides/embeddings?id=6464722084cdc277dfaa966a) offers an embeddings service.\n", + "\n", + "This example goes over how to use LangChain to interact with MiniMax Inference for text embedding." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "ExecuteTime": { + "end_time": "2023-05-24T15:13:15.397075Z", + "start_time": "2023-05-24T15:13:15.387540Z" + } + }, + "outputs": [], + "source": [ + "import os\n", + "\n", + "os.environ[\"MINIMAX_GROUP_ID\"] = \"MINIMAX_GROUP_ID\"\n", + "os.environ[\"MINIMAX_API_KEY\"] = \"MINIMAX_API_KEY\"" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "ExecuteTime": { + "end_time": "2023-05-24T15:13:17.176956Z", + "start_time": "2023-05-24T15:13:15.399076Z" + } + }, + "outputs": [], + "source": [ + "from langchain.embeddings import MiniMaxEmbeddings" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "ExecuteTime": { + "end_time": "2023-05-24T15:13:17.193751Z", + "start_time": "2023-05-24T15:13:17.182053Z" + } + }, + "outputs": [], + "source": [ + "embeddings = MiniMaxEmbeddings()" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "ExecuteTime": { + "end_time": "2023-05-24T15:13:17.844903Z", + "start_time": "2023-05-24T15:13:17.198751Z" + } + }, + "outputs": [], + "source": [ + "query_text = \"This is a test query.\"\n", + "query_result = embeddings.embed_query(query_text)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "ExecuteTime": { + "end_time": "2023-05-24T15:13:18.605339Z", + "start_time": "2023-05-24T15:13:17.845906Z" + } + }, + "outputs": [], + "source": [ + "document_text = \"This is a test document.\"\n", + "document_result = embeddings.embed_documents([document_text])" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "ExecuteTime": { + "end_time": "2023-05-24T15:13:18.620432Z", + "start_time": "2023-05-24T15:13:18.608335Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Cosine similarity between document and query: 0.1573236279277012\n" + ] + } + ], + "source": [ + "import numpy as np\n", + "\n", + "query_numpy = np.array(query_result)\n", + "document_numpy = np.array(document_result[0])\n", + "similarity = np.dot(query_numpy, document_numpy) / (np.linalg.norm(query_numpy)*np.linalg.norm(document_numpy))\n", + "print(f\"Cosine similarity between document and query: {similarity}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/langchain/embeddings/__init__.py b/langchain/embeddings/__init__.py index d0db607856..dc6b867269 100644 --- a/langchain/embeddings/__init__.py +++ b/langchain/embeddings/__init__.py @@ -17,6 +17,7 @@ from langchain.embeddings.huggingface import ( from langchain.embeddings.huggingface_hub import HuggingFaceHubEmbeddings from langchain.embeddings.jina import JinaEmbeddings from langchain.embeddings.llamacpp import LlamaCppEmbeddings +from langchain.embeddings.minimax import MiniMaxEmbeddings from langchain.embeddings.modelscope_hub import ModelScopeEmbeddings from langchain.embeddings.mosaicml import MosaicMLInstructorEmbeddings from langchain.embeddings.openai import OpenAIEmbeddings @@ -53,6 +54,7 @@ __all__ = [ "AlephAlphaSymmetricSemanticEmbedding", "SentenceTransformerEmbeddings", "GooglePalmEmbeddings", + "MiniMaxEmbeddings", "VertexAIEmbeddings", ] diff --git a/langchain/embeddings/minimax.py b/langchain/embeddings/minimax.py new file mode 100644 index 0000000000..c33ca0f292 --- /dev/null +++ b/langchain/embeddings/minimax.py @@ -0,0 +1,163 @@ +"""Wrapper around MiniMax APIs.""" +from __future__ import annotations + +import logging +from typing import Any, Callable, Dict, List, Optional + +import requests +from pydantic import BaseModel, Extra, root_validator +from tenacity import ( + before_sleep_log, + retry, + stop_after_attempt, + wait_exponential, +) + +from langchain.embeddings.base import Embeddings +from langchain.utils import get_from_dict_or_env + +logger = logging.getLogger(__name__) + + +def _create_retry_decorator() -> Callable[[Any], Any]: + """Returns a tenacity retry decorator.""" + + multiplier = 1 + min_seconds = 1 + max_seconds = 4 + max_retries = 6 + + return retry( + reraise=True, + stop=stop_after_attempt(max_retries), + wait=wait_exponential(multiplier=multiplier, min=min_seconds, max=max_seconds), + before_sleep=before_sleep_log(logger, logging.WARNING), + ) + + +def embed_with_retry(embeddings: MiniMaxEmbeddings, *args: Any, **kwargs: Any) -> Any: + """Use tenacity to retry the completion call.""" + retry_decorator = _create_retry_decorator() + + @retry_decorator + def _embed_with_retry(*args: Any, **kwargs: Any) -> Any: + return embeddings.embed(*args, **kwargs) + + return _embed_with_retry(*args, **kwargs) + + +class MiniMaxEmbeddings(BaseModel, Embeddings): + """Wrapper around MiniMax's embedding inference service. + + To use, you should have the environment variable ``MINIMAX_GROUP_ID`` and + ``MINIMAX_API_KEY`` set with your API token, or pass it as a named parameter to + the constructor. + + Example: + .. code-block:: python + + from langchain.embeddings import MiniMaxEmbeddings + embeddings = MiniMaxEmbeddings() + + query_text = "This is a test query." + query_result = embeddings.embed_query(query_text) + + document_text = "This is a test document." + document_result = embeddings.embed_documents([document_text]) + + """ + + endpoint_url: str = "https://api.minimax.chat/v1/embeddings" + """Endpoint URL to use.""" + model: str = "embo-01" + """Embeddings model name to use.""" + embed_type_db: str = "db" + """For embed_documents""" + embed_type_query: str = "query" + """For embed_query""" + + minimax_group_id: Optional[str] = None + """Group ID for MiniMax API.""" + minimax_api_key: Optional[str] = None + """API Key for MiniMax API.""" + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + + @root_validator() + def validate_environment(cls, values: Dict) -> Dict: + """Validate that group id and api key exists in environment.""" + minimax_group_id = get_from_dict_or_env( + values, "minimax_group_id", "MINIMAX_GROUP_ID" + ) + minimax_api_key = get_from_dict_or_env( + values, "minimax_api_key", "MINIMAX_API_KEY" + ) + values["minimax_group_id"] = minimax_group_id + values["minimax_api_key"] = minimax_api_key + return values + + def embed( + self, + texts: List[str], + embed_type: str, + ) -> List[List[float]]: + payload = { + "model": self.model, + "type": embed_type, + "texts": texts, + } + + # HTTP headers for authorization + headers = { + "Authorization": f"Bearer {self.minimax_api_key}", + "Content-Type": "application/json", + } + + params = { + "GroupId": self.minimax_group_id, + } + + # send request + response = requests.post( + self.endpoint_url, params=params, headers=headers, json=payload + ) + parsed_response = response.json() + + # check for errors + if parsed_response["base_resp"]["status_code"] != 0: + raise ValueError( + f"MiniMax API returned an error: {parsed_response['base_resp']}" + ) + + embeddings = parsed_response["vectors"] + + return embeddings + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + """Embed documents using a MiniMax embedding endpoint. + + Args: + texts: The list of texts to embed. + + Returns: + List of embeddings, one for each text. + """ + embeddings = embed_with_retry(self, texts=texts, embed_type=self.embed_type_db) + return embeddings + + def embed_query(self, text: str) -> List[float]: + """Embed a query using a MiniMax embedding endpoint. + + Args: + text: The text to embed. + + Returns: + Embeddings for the text. + """ + embeddings = embed_with_retry( + self, texts=[text], embed_type=self.embed_type_query + ) + return embeddings[0]