From bb7ac9edb58c05b420e6049eec6829741bfa40e8 Mon Sep 17 00:00:00 2001 From: wenmeng zhou Date: Mon, 12 Jun 2023 12:14:20 +0800 Subject: [PATCH] add dashscope text embedding (#5929) #### What I do Adding embedding api for [DashScope](https://help.aliyun.com/product/610100.html), which is the DAMO Academy's multilingual text unified vector model based on the LLM base. It caters to multiple mainstream languages worldwide and offers high-quality vector services, helping developers quickly transform text data into high-quality vector data. Currently supported languages include Chinese, English, Spanish, French, Portuguese, Indonesian, and more. #### Who can review? Models - @hwchase17 - @agola11 --------- Co-authored-by: Harrison Chase --- .../text_embedding/examples/dashscope.ipynb | 83 ++++++++++ langchain/embeddings/__init__.py | 2 + langchain/embeddings/dashscope.py | 155 ++++++++++++++++++ .../embeddings/test_dashscope.py | 55 +++++++ 4 files changed, 295 insertions(+) create mode 100644 docs/modules/models/text_embedding/examples/dashscope.ipynb create mode 100644 langchain/embeddings/dashscope.py create mode 100644 tests/integration_tests/embeddings/test_dashscope.py diff --git a/docs/modules/models/text_embedding/examples/dashscope.ipynb b/docs/modules/models/text_embedding/examples/dashscope.ipynb new file mode 100644 index 0000000000..7095ad5dc7 --- /dev/null +++ b/docs/modules/models/text_embedding/examples/dashscope.ipynb @@ -0,0 +1,83 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# DashScope\n", + "\n", + "Let's load the DashScope Embedding class." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.embeddings import DashScopeEmbeddings" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "embeddings = DashScopeEmbeddings(model='text-embedding-v1', dashscope_api_key='your-dashscope-api-key')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "text = \"This is a test document.\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "query_result = embeddings.embed_query(text)\n", + "print(query_result)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "doc_results = embeddings.embed_documents([\"foo\"])\n", + "print(doc_results)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "chatgpt", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.4" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/langchain/embeddings/__init__.py b/langchain/embeddings/__init__.py index a54ea9aa47..b68769c398 100644 --- a/langchain/embeddings/__init__.py +++ b/langchain/embeddings/__init__.py @@ -8,6 +8,7 @@ from langchain.embeddings.aleph_alpha import ( ) from langchain.embeddings.bedrock import BedrockEmbeddings from langchain.embeddings.cohere import CohereEmbeddings +from langchain.embeddings.dashscope import DashScopeEmbeddings from langchain.embeddings.deepinfra import DeepInfraEmbeddings from langchain.embeddings.elasticsearch import ElasticsearchEmbeddings from langchain.embeddings.embaas import EmbaasEmbeddings @@ -61,6 +62,7 @@ __all__ = [ "VertexAIEmbeddings", "BedrockEmbeddings", "DeepInfraEmbeddings", + "DashScopeEmbeddings", "EmbaasEmbeddings", ] diff --git a/langchain/embeddings/dashscope.py b/langchain/embeddings/dashscope.py new file mode 100644 index 0000000000..1db6dd1d53 --- /dev/null +++ b/langchain/embeddings/dashscope.py @@ -0,0 +1,155 @@ +"""Wrapper around DashScope embedding models.""" +from __future__ import annotations + +import logging +from typing import ( + Any, + Callable, + Dict, + List, + Optional, +) + +from pydantic import BaseModel, Extra, root_validator +from requests.exceptions import HTTPError +from tenacity import ( + before_sleep_log, + retry, + retry_if_exception_type, + stop_after_attempt, + wait_exponential, +) + +from langchain.embeddings.base import Embeddings +from langchain.utils import get_from_dict_or_env + +logger = logging.getLogger(__name__) + + +def _create_retry_decorator(embeddings: DashScopeEmbeddings) -> Callable[[Any], Any]: + multiplier = 1 + min_seconds = 1 + max_seconds = 4 + # Wait 2^x * 1 second between each retry starting with + # 1 seconds, then up to 4 seconds, then 4 seconds afterwards + return retry( + reraise=True, + stop=stop_after_attempt(embeddings.max_retries), + wait=wait_exponential(multiplier, min=min_seconds, max=max_seconds), + retry=(retry_if_exception_type(HTTPError)), + before_sleep=before_sleep_log(logger, logging.WARNING), + ) + + +def embed_with_retry(embeddings: DashScopeEmbeddings, **kwargs: Any) -> Any: + """Use tenacity to retry the embedding call.""" + retry_decorator = _create_retry_decorator(embeddings) + + @retry_decorator + def _embed_with_retry(**kwargs: Any) -> Any: + resp = embeddings.client.call(**kwargs) + if resp.status_code == 200: + return resp.output["embeddings"] + elif resp.status_code in [400, 401]: + raise ValueError( + f"status_code: {resp.status_code} \n " + f"code: {resp.code} \n message: {resp.message}" + ) + else: + raise HTTPError( + f"HTTP error occurred: status_code: {resp.status_code} \n " + f"code: {resp.code} \n message: {resp.message}" + ) + + return _embed_with_retry(**kwargs) + + +class DashScopeEmbeddings(BaseModel, Embeddings): + """Wrapper around DashScope embedding models. + + To use, you should have the ``dashscope`` python package installed, and the + environment variable ``DASHSCOPE_API_KEY`` set with your API key or pass it + as a named parameter to the constructor. + + Example: + .. code-block:: python + + from langchain.embeddings import DashScopeEmbeddings + embeddings = DashScopeEmbeddings(dashscope_api_key="my-api-key") + + Example: + .. code-block:: python + + import os + os.environ["DASHSCOPE_API_KEY"] = "your DashScope API KEY" + + from langchain.embeddings.dashscope import DashScopeEmbeddings + embeddings = DashScopeEmbeddings( + model="text-embedding-v1", + ) + text = "This is a test query." + query_result = embeddings.embed_query(text) + + """ + + client: Any #: :meta private: + model: str = "text-embedding-v1" + dashscope_api_key: Optional[str] = None + """Maximum number of retries to make when generating.""" + max_retries: int = 5 + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + + @root_validator() + def validate_environment(cls, values: Dict) -> Dict: + import dashscope + + """Validate that api key and python package exists in environment.""" + values["dashscope_api_key"] = get_from_dict_or_env( + values, "dashscope_api_key", "DASHSCOPE_API_KEY" + ) + dashscope.api_key = values["dashscope_api_key"] + try: + import dashscope + + values["client"] = dashscope.TextEmbedding + except ImportError: + raise ImportError( + "Could not import dashscope python package. " + "Please install it with `pip install dashscope`." + ) + return values + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + """Call out to DashScope's embedding endpoint for embedding search docs. + + Args: + texts: The list of texts to embed. + chunk_size: The chunk size of embeddings. If None, will use the chunk size + specified by the class. + + Returns: + List of embeddings, one for each text. + """ + embeddings = embed_with_retry( + self, input=texts, text_type="document", model=self.model + ) + embedding_list = [item["embedding"] for item in embeddings] + return embedding_list + + def embed_query(self, text: str) -> List[float]: + """Call out to DashScope's embedding endpoint for embedding query text. + + Args: + text: The text to embed. + + Returns: + Embedding for the text. + """ + embedding = embed_with_retry( + self, input=text, text_type="query", model=self.model + )[0]["embedding"] + return embedding diff --git a/tests/integration_tests/embeddings/test_dashscope.py b/tests/integration_tests/embeddings/test_dashscope.py new file mode 100644 index 0000000000..f61c3805e1 --- /dev/null +++ b/tests/integration_tests/embeddings/test_dashscope.py @@ -0,0 +1,55 @@ +"""Test dashscope embeddings.""" +import numpy as np + +from langchain.embeddings.dashscope import DashScopeEmbeddings + + +def test_dashscope_embedding_documents() -> None: + """Test dashscope embeddings.""" + documents = ["foo bar"] + embedding = DashScopeEmbeddings(model="text-embedding-v1") + output = embedding.embed_documents(documents) + assert len(output) == 1 + assert len(output[0]) == 1536 + + +def test_dashscope_embedding_documents_multiple() -> None: + """Test dashscope embeddings.""" + documents = ["foo bar", "bar foo", "foo"] + embedding = DashScopeEmbeddings(model="text-embedding-v1") + output = embedding.embed_documents(documents) + assert len(output) == 3 + assert len(output[0]) == 1536 + assert len(output[1]) == 1536 + assert len(output[2]) == 1536 + + +def test_dashscope_embedding_query() -> None: + """Test dashscope embeddings.""" + document = "foo bar" + embedding = DashScopeEmbeddings(model="text-embedding-v1") + output = embedding.embed_query(document) + assert len(output) == 1536 + + +def test_dashscope_embedding_with_empty_string() -> None: + """Test dashscope embeddings with empty string.""" + import dashscope + + document = ["", "abc"] + embedding = DashScopeEmbeddings(model="text-embedding-v1") + output = embedding.embed_documents(document) + assert len(output) == 2 + assert len(output[0]) == 1536 + expected_output = dashscope.TextEmbedding.call( + input="", model="text-embedding-v1", text_type="document" + ).output["embeddings"][0]["embedding"] + assert np.allclose(output[0], expected_output) + assert len(output[1]) == 1536 + + +if __name__ == "__main__": + test_dashscope_embedding_documents() + test_dashscope_embedding_documents_multiple() + test_dashscope_embedding_query() + test_dashscope_embedding_with_empty_string()