diff --git a/libs/community/langchain_community/embeddings/gradient_ai.py b/libs/community/langchain_community/embeddings/gradient_ai.py
index d5c9bf9a49..6fd05d3c69 100644
--- a/libs/community/langchain_community/embeddings/gradient_ai.py
+++ b/libs/community/langchain_community/embeddings/gradient_ai.py
@@ -1,15 +1,9 @@
-import asyncio
-import logging
-import os
-from concurrent.futures import ThreadPoolExecutor
-from typing import Any, Callable, Dict, List, Optional, Tuple
-
-import aiohttp
-import numpy as np
-import requests
+from typing import Any, Dict, List, Optional
+
 from langchain_core.embeddings import Embeddings
 from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
 from langchain_core.utils import get_from_dict_or_env
+from packaging.version import parse
 
 __all__ = ["GradientEmbeddings"]
 
@@ -49,6 +43,9 @@ class GradientEmbeddings(BaseModel, Embeddings):
     gradient_api_url: str = "https://api.gradient.ai/api"
     """Endpoint URL to use."""
 
+    query_prompt_for_retrieval: Optional[str] = None
+    """Optional prompt prepended to queries at retrieval time."""
+
     client: Any = None  #: :meta private:
     """Gradient client."""
 
@@ -72,21 +69,24 @@ class GradientEmbeddings(BaseModel, Embeddings):
         values["gradient_api_url"] = get_from_dict_or_env(
             values, "gradient_api_url", "GRADIENT_API_URL"
         )
+        try:
+            import gradientai
+        except ImportError:
+            raise ImportError(
+                'GradientEmbeddings requires `pip install -U "gradientai>=1.4.0"`.'
+            )
 
-        values["client"] = TinyAsyncGradientEmbeddingClient(
+        if parse(gradientai.__version__) < parse("1.4.0"):
+            raise ImportError(
+                'GradientEmbeddings requires `pip install -U "gradientai>=1.4.0"`.'
+            )
+
+        gradient = gradientai.Gradient(
             access_token=values["gradient_access_token"],
             workspace_id=values["gradient_workspace_id"],
             host=values["gradient_api_url"],
         )
-        try:
-            import gradientai  # noqa
-        except ImportError:
-            logging.warning(
-                "DeprecationWarning: `GradientEmbeddings` will use "
-                "`pip install gradientai` in future releases of langchain."
-            )
-        except Exception:
-            pass
+        values["client"] = gradient.get_embeddings_model(slug=values["model"])
 
         return values
 
@@ -99,11 +99,11 @@ class GradientEmbeddings(BaseModel, Embeddings):
         Returns:
             List of embeddings, one for each text.
         """
-        embeddings = self.client.embed(
-            model=self.model,
-            texts=texts,
-        )
-        return embeddings
+        inputs = [{"input": text} for text in texts]
+
+        result = self.client.embed(inputs=inputs).embeddings
+
+        return [e.embedding for e in result]
 
     async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
         """Async call out to Gradient's embedding endpoint.
@@ -114,11 +114,11 @@ class GradientEmbeddings(BaseModel, Embeddings):
         Returns:
             List of embeddings, one for each text.
         """
-        embeddings = await self.client.aembed(
-            model=self.model,
-            texts=texts,
-        )
-        return embeddings
+        inputs = [{"input": text} for text in texts]
+
+        result = (await self.client.aembed(inputs=inputs)).embeddings
+
+        return [e.embedding for e in result]
 
     def embed_query(self, text: str) -> List[float]:
         """Call out to Gradient's embedding endpoint.
@@ -129,7 +129,12 @@ class GradientEmbeddings(BaseModel, Embeddings):
         Returns:
             Embeddings for the text.
         """
-        return self.embed_documents([text])[0]
+        query = (
+            f"{self.query_prompt_for_retrieval} {text}"
+            if self.query_prompt_for_retrieval
+            else text
+        )
+        return self.embed_documents([query])[0]
 
     async def aembed_query(self, text: str) -> List[float]:
         """Async call out to Gradient's embedding endpoint.
@@ -140,240 +145,22 @@ class GradientEmbeddings(BaseModel, Embeddings):
         Returns:
             Embeddings for the text.
         """
-        embeddings = await self.aembed_documents([text])
+        query = (
+            f"{self.query_prompt_for_retrieval} {text}"
+            if self.query_prompt_for_retrieval
+            else text
+        )
+        embeddings = await self.aembed_documents([query])
         return embeddings[0]
 
 
 class TinyAsyncGradientEmbeddingClient:  #: :meta private:
-    """A helper tool to embed Gradient. Not part of Langchain's or Gradients stable API,
-    direct use discouraged.
-
-    To use, set the environment variable ``GRADIENT_ACCESS_TOKEN`` with your
-    API token and ``GRADIENT_WORKSPACE_ID`` for your gradient workspace,
-    or alternatively provide them as keywords to the constructor of this class.
-
-    Example:
-        .. code-block:: python
-
-
-            mini_client = TinyAsyncGradientEmbeddingClient(
-                workspace_id="12345614fc0_workspace",
-                access_token="gradientai-access_token",
-            )
-            embeds = mini_client.embed(
-                model="bge-large",
-                text=["doc1", "doc2"]
-            )
-            # or
-            embeds = await mini_client.aembed(
-                model="bge-large",
-                text=["doc1", "doc2"]
-            )
+    """Deprecated: TinyAsyncGradientEmbeddingClient was removed.
+    This stub exists only for backwards compatibility with older
+    versions of langchain_community.
+    It may be removed entirely in the future.
     """
 
-    def __init__(
-        self,
-        access_token: Optional[str] = None,
-        workspace_id: Optional[str] = None,
-        host: str = "https://api.gradient.ai/api",
-        aiosession: Optional[aiohttp.ClientSession] = None,
-    ) -> None:
-        self.access_token = access_token or os.environ.get(
-            "GRADIENT_ACCESS_TOKEN", None
-        )
-        self.workspace_id = workspace_id or os.environ.get(
-            "GRADIENT_WORKSPACE_ID", None
-        )
-        self.host = host
-        self.aiosession = aiosession
-
-        if self.access_token is None or len(self.access_token) < 10:
-            raise ValueError(
-                "env variable `GRADIENT_ACCESS_TOKEN` or "
-                " param `access_token` must be set "
-            )
-
-        if self.workspace_id is None or len(self.workspace_id) < 3:
-            raise ValueError(
-                "env variable `GRADIENT_WORKSPACE_ID` or "
-                " param `workspace_id` must be set"
-            )
-
-        if self.host is None or len(self.host) < 3:
-            raise ValueError(" param `host` must be set to a valid url")
-        self._batch_size = 128
-
-    @staticmethod
-    def _permute(
-        texts: List[str], sorter: Callable = len
-    ) -> Tuple[List[str], Callable]:
-        """Sort texts in ascending order, and
-        delivers a lambda expr, which can sort a same length list
-        https://github.com/UKPLab/sentence-transformers/blob/
-        c5f93f70eca933c78695c5bc686ceda59651ae3b/sentence_transformers/SentenceTransformer.py#L156
-
-        Args:
-            texts (List[str]): _description_
-            sorter (Callable, optional): _description_. Defaults to len.
-
-        Returns:
-            Tuple[List[str], Callable]: _description_
-
-        Example:
-            ```
-            texts = ["one","three","four"]
-            perm_texts, undo = self._permute(texts)
-            texts == undo(perm_texts)
-            ```
-        """
-
-        if len(texts) == 1:
-            # special case query
-            return texts, lambda t: t
-        length_sorted_idx = np.argsort([-sorter(sen) for sen in texts])
-        texts_sorted = [texts[idx] for idx in length_sorted_idx]
-
-        return texts_sorted, lambda unsorted_embeddings: [  # noqa E731
-            unsorted_embeddings[idx] for idx in np.argsort(length_sorted_idx)
-        ]
-
-    def _batch(self, texts: List[str]) -> List[List[str]]:
-        """
-        splits Lists of text parts into batches of size max `self._batch_size`
-        When encoding vector database,
-
-        Args:
-            texts (List[str]): List of sentences
-            self._batch_size (int, optional): max batch size of one request.
-
-        Returns:
-            List[List[str]]: Batches of List of sentences
-        """
-        if len(texts) == 1:
-            # special case query
-            return [texts]
-        batches = []
-        for start_index in range(0, len(texts), self._batch_size):
-            batches.append(texts[start_index : start_index + self._batch_size])
-        return batches
-
-    @staticmethod
-    def _unbatch(batch_of_texts: List[List[Any]]) -> List[Any]:
-        if len(batch_of_texts) == 1 and len(batch_of_texts[0]) == 1:
-            # special case query
-            return batch_of_texts[0]
-        texts = []
-        for sublist in batch_of_texts:
-            texts.extend(sublist)
-        return texts
-
-    def _kwargs_post_request(self, model: str, texts: List[str]) -> Dict[str, Any]:
-        """Build the kwargs for the Post request, used by sync
-
-        Args:
-            model (str): _description_
-            texts (List[str]): _description_
-
-        Returns:
-            Dict[str, Collection[str]]: _description_
-        """
-        return dict(
-            url=f"{self.host}/embeddings/{model}",
-            headers={
-                "authorization": f"Bearer {self.access_token}",
-                "x-gradient-workspace-id": f"{self.workspace_id}",
-                "accept": "application/json",
-                "content-type": "application/json",
-            },
-            json=dict(
-                inputs=[{"input": i} for i in texts],
-            ),
-        )
-
-    def _sync_request_embed(
-        self, model: str, batch_texts: List[str]
-    ) -> List[List[float]]:
-        response = requests.post(
-            **self._kwargs_post_request(model=model, texts=batch_texts)
-        )
-        if response.status_code != 200:
-            raise Exception(
-                f"Gradient returned an unexpected response with status "
-                f"{response.status_code}: {response.text}"
-            )
-        return [e["embedding"] for e in response.json()["embeddings"]]
-
-    def embed(self, model: str, texts: List[str]) -> List[List[float]]:
-        """call the embedding of model
-
-        Args:
-            model (str): to embedding model
-            texts (List[str]): List of sentences to embed.
-
-        Returns:
-            List[List[float]]: List of vectors for each sentence
-        """
-        perm_texts, unpermute_func = self._permute(texts)
-        perm_texts_batched = self._batch(perm_texts)
-
-        # Request
-        map_args = (
-            self._sync_request_embed,
-            [model] * len(perm_texts_batched),
-            perm_texts_batched,
-        )
-        if len(perm_texts_batched) == 1:
-            embeddings_batch_perm = list(map(*map_args))
-        else:
-            with ThreadPoolExecutor(32) as p:
-                embeddings_batch_perm = list(p.map(*map_args))
-
-        embeddings_perm = self._unbatch(embeddings_batch_perm)
-        embeddings = unpermute_func(embeddings_perm)
-        return embeddings
-
-    async def _async_request(
-        self, session: aiohttp.ClientSession, kwargs: Dict[str, Any]
-    ) -> List[List[float]]:
-        async with session.post(**kwargs) as response:
-            if response.status != 200:
-                raise Exception(
-                    f"Gradient returned an unexpected response with status "
-                    f"{response.status}: {response.text}"
-                )
-            embedding = (await response.json())["embeddings"]
-            return [e["embedding"] for e in embedding]
-
-    async def aembed(self, model: str, texts: List[str]) -> List[List[float]]:
-        """call the embedding of model, async method
-
-        Args:
-            model (str): to embedding model
-            texts (List[str]): List of sentences to embed.
-
-        Returns:
-            List[List[float]]: List of vectors for each sentence
-        """
-        perm_texts, unpermute_func = self._permute(texts)
-        perm_texts_batched = self._batch(perm_texts)
-
-        # Request
-        if self.aiosession is None:
-            self.aiosession = aiohttp.ClientSession(
-                trust_env=True, connector=aiohttp.TCPConnector(limit=32)
-            )
-        async with self.aiosession as session:
-            embeddings_batch_perm = await asyncio.gather(
-                *[
-                    self._async_request(
-                        session=session,
-                        **self._kwargs_post_request(model=model, texts=t),
-                    )
-                    for t in perm_texts_batched
-                ]
-            )
-
-        embeddings_perm = self._unbatch(embeddings_batch_perm)
-        embeddings = unpermute_func(embeddings_perm)
-        return embeddings
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        raise ValueError("Deprecated: TinyAsyncGradientEmbeddingClient was removed.")
diff --git a/libs/community/poetry.lock b/libs/community/poetry.lock
index 160e328cff..e4dffd10f5 100644
--- a/libs/community/poetry.lock
+++ b/libs/community/poetry.lock
@@ -1,5 +1,17 @@
 # This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.
 
+[[package]]
+name = "aenum"
+version = "3.1.15"
+description = "Advanced Enumerations (compatible with Python's stdlib Enum), NamedTuples, and NamedConstants"
+optional = true
+python-versions = "*"
+files = [
+    {file = "aenum-3.1.15-py2-none-any.whl", hash = "sha256:27b1710b9d084de6e2e695dab78fe9f269de924b51ae2850170ee7e1ca6288a5"},
+    {file = "aenum-3.1.15-py3-none-any.whl", hash = "sha256:e0dfaeea4c2bd362144b87377e2c61d91958c5ed0b4daf89cb6f45ae23af6288"},
+    {file = "aenum-3.1.15.tar.gz", hash = "sha256:8cbd76cd18c4f870ff39b24284d3ea028fbe8731a58df3aa581e434c575b9559"},
+]
+
 [[package]]
 name = "aiodns"
 version = "3.1.1"
@@ -2352,6 +2364,23 @@ test = ["aiofiles", "aiohttp (>=3.7.1,<3.9.0)", "botocore (>=1.21,<2)", "mock (=
 test-no-transport = ["aiofiles", "mock (==4.0.2)", "parse (==1.15.0)", "pytest (==6.2.5)", "pytest-asyncio (==0.16.0)", "pytest-console-scripts (==1.3.1)", "pytest-cov (==3.0.0)", "vcrpy (==4.0.2)"]
 websockets = ["websockets (>=10,<11)", "websockets (>=9,<10)"]
 
+[[package]]
+name = "gradientai"
+version = "1.4.0"
+description = "Gradient AI API"
+optional = true
+python-versions = ">=3.8.1,<4.0.0"
+files = [
+    {file = "gradientai-1.4.0-py3-none-any.whl", hash = "sha256:58b74151e4bee534d438509303bcca3a9b84d17dafff31c206353489b54fcbfa"},
+    {file = "gradientai-1.4.0.tar.gz", hash = "sha256:98b9e0894530c6b7c675a113010dca7f7f7c399e02c46c0fb5532bf9fc1609f4"},
+]
+
+[package.dependencies]
+aenum = ">=3.1.11"
+pydantic = ">=1.10.5,<2.0.0"
+python-dateutil = ">=2.8.2"
+urllib3 = ">=1.25.3"
+
 [[package]]
 name = "graphql-core"
 version = "3.2.3"
@@ -3023,7 +3052,6 @@ files = [
     {file = "jq-1.6.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:227b178b22a7f91ae88525810441791b1ca1fc71c86f03190911793be15cec3d"},
     {file = "jq-1.6.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:780eb6383fbae12afa819ef676fc93e1548ae4b076c004a393af26a04b460742"},
     {file = "jq-1.6.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:08ded6467f4ef89fec35b2bf310f210f8cd13fbd9d80e521500889edf8d22441"},
-    {file = "jq-1.6.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:49e44ed677713f4115bd5bf2dbae23baa4cd503be350e12a1c1f506b0687848f"},
     {file = "jq-1.6.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:984f33862af285ad3e41e23179ac4795f1701822473e1a26bf87ff023e5a89ea"},
    {file = "jq-1.6.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f42264fafc6166efb5611b5d4cb01058887d050a6c19334f6a3f8a13bb369df5"},
"sha256:f42264fafc6166efb5611b5d4cb01058887d050a6c19334f6a3f8a13bb369df5"}, {file = "jq-1.6.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a67154f150aaf76cc1294032ed588436eb002097dd4fd1e283824bf753a05080"}, @@ -3421,7 +3449,7 @@ files = [ [[package]] name = "langchain-core" -version = "0.1.0" +version = "0.1.1" description = "Building applications with LLMs through composability" optional = false python-versions = ">=3.8.1,<4.0" @@ -5529,6 +5557,7 @@ files = [ {file = "pymongo-4.6.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b8729dbf25eb32ad0dc0b9bd5e6a0d0b7e5c2dc8ec06ad171088e1896b522a74"}, {file = "pymongo-4.6.1-cp312-cp312-win32.whl", hash = "sha256:3177f783ae7e08aaf7b2802e0df4e4b13903520e8380915e6337cdc7a6ff01d8"}, {file = "pymongo-4.6.1-cp312-cp312-win_amd64.whl", hash = "sha256:00c199e1c593e2c8b033136d7a08f0c376452bac8a896c923fcd6f419e07bdd2"}, + {file = "pymongo-4.6.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:6dcc95f4bb9ed793714b43f4f23a7b0c57e4ef47414162297d6f650213512c19"}, {file = "pymongo-4.6.1-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:13552ca505366df74e3e2f0a4f27c363928f3dff0eef9f281eb81af7f29bc3c5"}, {file = "pymongo-4.6.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:77e0df59b1a4994ad30c6d746992ae887f9756a43fc25dec2db515d94cf0222d"}, {file = "pymongo-4.6.1-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:3a7f02a58a0c2912734105e05dedbee4f7507e6f1bd132ebad520be0b11d46fd"}, @@ -8480,9 +8509,9 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [extras] cli = ["typer"] -extended-testing = ["aiosqlite", "aleph-alpha-client", "anthropic", "arxiv", "assemblyai", "atlassian-python-api", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "cohere", "dashvector", "databricks-vectorsearch", "datasets", "dgml-utils", "esprima", "faiss-cpu", "feedparser", "fireworks-ai", "geopandas", "gitpython", "google-cloud-documentai", "gql", "hologres-vector", "html2text", "javelin-sdk", "jinja2", "jq", "jsonschema", "lxml", "markdownify", "motor", "msal", "mwparserfromhell", "mwxml", "newspaper3k", "numexpr", "openai", "openapi-pydantic", "pandas", "pdfminer-six", "pgvector", "praw", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "rapidocr-onnxruntime", "requests-toolbelt", "rspace_client", "scikit-learn", "sqlite-vss", "streamlit", "sympy", "telethon", "timescale-vector", "tqdm", "upstash-redis", "xata", "xmltodict"] +extended-testing = ["aiosqlite", "aleph-alpha-client", "anthropic", "arxiv", "assemblyai", "atlassian-python-api", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "cohere", "dashvector", "databricks-vectorsearch", "datasets", "dgml-utils", "esprima", "faiss-cpu", "feedparser", "fireworks-ai", "geopandas", "gitpython", "google-cloud-documentai", "gql", "gradientai", "hologres-vector", "html2text", "javelin-sdk", "jinja2", "jq", "jsonschema", "lxml", "markdownify", "motor", "msal", "mwparserfromhell", "mwxml", "newspaper3k", "numexpr", "openai", "openapi-pydantic", "pandas", "pdfminer-six", "pgvector", "praw", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "rapidocr-onnxruntime", "requests-toolbelt", "rspace_client", "scikit-learn", "sqlite-vss", "streamlit", "sympy", "telethon", "timescale-vector", "tqdm", "upstash-redis", "xata", "xmltodict"] [metadata] lock-version = "2.0" python-versions = 
">=3.8.1,<4.0" -content-hash = "e3bacf389a13d283c4dd29e3a673e1863826b4e98785c666fefc10cf714c2f6f" +content-hash = "ab4b1efe33110b575d2fb65bd5ecb90e92d1bd83dd5eac87080e4d07268df72f" diff --git a/libs/community/pyproject.toml b/libs/community/pyproject.toml index b5cfa09489..9a16ed378a 100644 --- a/libs/community/pyproject.toml +++ b/libs/community/pyproject.toml @@ -28,6 +28,7 @@ openai = {version = "<2", optional = true} arxiv = {version = "^1.4", optional = true} pypdf = {version = "^3.4.0", optional = true} aleph-alpha-client = {version="^2.15.0", optional = true} +gradientai = {version="^1.4.0", optional = true} pgvector = {version = "^0.1.6", optional = true} atlassian-python-api = {version = "^3.36.0", optional=true} html2text = {version="^2020.1.16", optional=true} @@ -203,6 +204,7 @@ extended_testing = [ "telethon", "psychicapi", "gql", + "gradientai", "requests-toolbelt", "html2text", "numexpr", diff --git a/libs/community/tests/unit_tests/embeddings/test_gradient_ai.py b/libs/community/tests/unit_tests/embeddings/test_gradient_ai.py index c5001b1ea8..fa57109ef0 100644 --- a/libs/community/tests/unit_tests/embeddings/test_gradient_ai.py +++ b/libs/community/tests/unit_tests/embeddings/test_gradient_ai.py @@ -1,7 +1,8 @@ -from typing import Dict +import sys +from typing import Any, Dict, List +from unittest.mock import MagicMock, patch import pytest -from pytest_mock import MockerFixture from langchain_community.embeddings import GradientEmbeddings @@ -11,117 +12,93 @@ _GRADIENT_WORKSPACE_ID = "valid_workspace_12345" _GRADIENT_BASE_URL = "https://api.gradient.ai/api" _DOCUMENTS = [ "pizza", - "another pizza", + "another long pizza", "a document", - "another pizza", + "another long pizza", "super long document with many tokens", ] -class MockResponse: - def __init__(self, json_data: Dict, status_code: int): - self.json_data = json_data - self.status_code = status_code - - def json(self) -> Dict: - return self.json_data - - -def mocked_requests_post( - url: str, - headers: dict, - json: dict, -) -> MockResponse: - assert url.startswith(_GRADIENT_BASE_URL) - assert _MODEL_ID in url - assert json - assert headers - - assert headers.get("authorization") == f"Bearer {_GRADIENT_SECRET}" - assert headers.get("x-gradient-workspace-id") == f"{_GRADIENT_WORKSPACE_ID}" - - assert "inputs" in json and "input" in json["inputs"][0] - embeddings = [] - for inp in json["inputs"]: - # verify correct ordering - inp = inp["input"] - if "pizza" in inp: - v = [1.0, 0.0, 0.0] - elif "document" in inp: - v = [0.0, 0.9, 0.0] - else: - v = [0.0, 0.0, -1.0] - if len(inp) > 10: - v[2] += 0.1 - embeddings.append({"embedding": v}) - - return MockResponse( - json_data={"embeddings": embeddings}, - status_code=200, - ) - - -def test_gradient_llm_sync( - mocker: MockerFixture, -) -> None: - mocker.patch("requests.post", side_effect=mocked_requests_post) - - embedder = GradientEmbeddings( - gradient_api_url=_GRADIENT_BASE_URL, - gradient_access_token=_GRADIENT_SECRET, - gradient_workspace_id=_GRADIENT_WORKSPACE_ID, - model=_MODEL_ID, - ) - assert embedder.gradient_access_token == _GRADIENT_SECRET - assert embedder.gradient_api_url == _GRADIENT_BASE_URL - assert embedder.gradient_workspace_id == _GRADIENT_WORKSPACE_ID - assert embedder.model == _MODEL_ID - - response = embedder.embed_documents(_DOCUMENTS) - want = [ - [1.0, 0.0, 0.0], # pizza - [1.0, 0.0, 0.1], # pizza + long - [0.0, 0.9, 0.0], # doc - [1.0, 0.0, 0.1], # pizza + long - [0.0, 0.9, 0.1], # doc + long - ] - - assert response == want - - -def 
-    mocker: MockerFixture,
-) -> None:
-    mocker.patch("requests.post", side_effect=mocked_requests_post)
-
-    embedder = GradientEmbeddings(
-        gradient_api_url=_GRADIENT_BASE_URL,
-        gradient_access_token=_GRADIENT_SECRET,
-        gradient_workspace_id=_GRADIENT_WORKSPACE_ID,
-        model=_MODEL_ID,
-    )
-    assert embedder.gradient_access_token == _GRADIENT_SECRET
-    assert embedder.gradient_api_url == _GRADIENT_BASE_URL
-    assert embedder.gradient_workspace_id == _GRADIENT_WORKSPACE_ID
-    assert embedder.model == _MODEL_ID
-
-    response = embedder.embed_documents(_DOCUMENTS * 1024)
-    want = [
-        [1.0, 0.0, 0.0],  # pizza
-        [1.0, 0.0, 0.1],  # pizza + long
-        [0.0, 0.9, 0.0],  # doc
-        [1.0, 0.0, 0.1],  # pizza + long
-        [0.0, 0.9, 0.1],  # doc + long
-    ] * 1024
-
-    assert response == want
-
-
-def test_gradient_wrong_setup(
-    mocker: MockerFixture,
-) -> None:
-    mocker.patch("requests.post", side_effect=mocked_requests_post)
-
+class GradientEmbeddingsModel(MagicMock):
+    """Mock Gradient embeddings model."""
+
+    def embed(self, inputs: List[Dict[str, str]]) -> Any:
+        """Return a deterministic fake embedding for each input."""
+        output = MagicMock()
+
+        embeddings = []
+        for inp in inputs:
+            # verify correct ordering
+            inp = inp["input"]
+            if "pizza" in inp:
+                v = [1.0, 0.0, 0.0]
+            elif "document" in inp:
+                v = [0.0, 0.9, 0.0]
+            else:
+                v = [0.0, 0.0, -1.0]
+            if len(inp) > 10:
+                v[2] += 0.1
+            output_inner = MagicMock()
+            output_inner.embedding = v
+            embeddings.append(output_inner)
+
+        output.embeddings = embeddings
+        return output
+
+    async def aembed(self, *args: Any) -> Any:
+        return self.embed(*args)
+
+
+class MockGradient(MagicMock):
+    """Mock Gradient client."""
+
+    def __init__(self, access_token: str, workspace_id: str, host: str) -> None:
+        assert access_token == _GRADIENT_SECRET
+        assert workspace_id == _GRADIENT_WORKSPACE_ID
+        assert host == _GRADIENT_BASE_URL
+
+    def get_embeddings_model(self, slug: str) -> GradientEmbeddingsModel:
+        assert slug == _MODEL_ID
+        return GradientEmbeddingsModel()
+
+    def close(self) -> None:
+        """Mock Gradient close."""
+        return
+
+
+class MockGradientaiPackage(MagicMock):
+    """Mock gradientai package."""
+
+    Gradient = MockGradient
+    __version__ = "1.4.0"
+
+
+def test_gradient_llm_sync() -> None:
+    with patch.dict(sys.modules, {"gradientai": MockGradientaiPackage()}):
+        embedder = GradientEmbeddings(
+            gradient_api_url=_GRADIENT_BASE_URL,
+            gradient_access_token=_GRADIENT_SECRET,
+            gradient_workspace_id=_GRADIENT_WORKSPACE_ID,
+            model=_MODEL_ID,
+        )
+        assert embedder.gradient_access_token == _GRADIENT_SECRET
+        assert embedder.gradient_api_url == _GRADIENT_BASE_URL
+        assert embedder.gradient_workspace_id == _GRADIENT_WORKSPACE_ID
+        assert embedder.model == _MODEL_ID
+
+        response = embedder.embed_documents(_DOCUMENTS)
+        want = [
+            [1.0, 0.0, 0.0],  # pizza
+            [1.0, 0.0, 0.1],  # pizza + long
+            [0.0, 0.9, 0.0],  # doc
+            [1.0, 0.0, 0.1],  # pizza + long
+            [0.0, 0.9, 0.1],  # doc + long
+        ]
+
+        assert response == want
+
+
+def test_gradient_wrong_setup() -> None:
     with pytest.raises(Exception):
         GradientEmbeddings(
             gradient_api_url=_GRADIENT_BASE_URL,
@@ -130,6 +107,8 @@ def test_gradient_wrong_setup(
             model=_MODEL_ID,
         )
 
+
+def test_gradient_wrong_setup2() -> None:
     with pytest.raises(Exception):
         GradientEmbeddings(
             gradient_api_url=_GRADIENT_BASE_URL,
@@ -138,6 +117,8 @@ def test_gradient_wrong_setup(
             model=_MODEL_ID,
         )
 
+
+def test_gradient_wrong_setup3() -> None:
     with pytest.raises(Exception):
         GradientEmbeddings(
             gradient_api_url="-",  # empty
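Reviewer note: the test refactor above swaps `requests.post`-level mocking for module-level mocking, so no HTTP layer is exercised at all. A minimal standalone sketch of the pattern (the fake package and the version value mirror the mocks above; everything else is illustrative):

```python
import sys
from unittest.mock import MagicMock, patch

# Stand-in for the `gradientai` package. Attributes set here are whatever the
# code under test looks up after `import gradientai`.
fake_gradientai = MagicMock()
fake_gradientai.__version__ = "1.4.0"  # satisfies the >=1.4.0 version gate

# While the patch is active, `import gradientai` anywhere (including inside
# the pydantic root_validator) resolves to the mock via sys.modules.
with patch.dict(sys.modules, {"gradientai": fake_gradientai}):
    import gradientai

    assert gradientai.__version__ == "1.4.0"
```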
diff --git a/libs/langchain/langchain/embeddings/gradient_ai.py b/libs/langchain/langchain/embeddings/gradient_ai.py
index 386be5babb..b3866edeb1 100644
--- a/libs/langchain/langchain/embeddings/gradient_ai.py
+++ b/libs/langchain/langchain/embeddings/gradient_ai.py
@@ -1,6 +1,3 @@
-from langchain_community.embeddings.gradient_ai import (
-    GradientEmbeddings,
-    TinyAsyncGradientEmbeddingClient,
-)
+from langchain_community.embeddings.gradient_ai import GradientEmbeddings
 
-__all__ = ["GradientEmbeddings", "TinyAsyncGradientEmbeddingClient"]
+__all__ = ["GradientEmbeddings"]
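For completeness, a usage sketch of the refactored class against a real workspace. The `bge-large` slug and the placeholder credentials come from the removed helper's docstring; the retrieval pre-prompt string is an illustrative value for the new `query_prompt_for_retrieval` field, not something this PR prescribes:

```python
from langchain_community.embeddings import GradientEmbeddings

# Requires `pip install -U "gradientai>=1.4.0"`. Credentials may also be
# supplied via the GRADIENT_ACCESS_TOKEN / GRADIENT_WORKSPACE_ID env vars.
embeddings = GradientEmbeddings(
    model="bge-large",
    gradient_access_token="gradientai-access_token",  # placeholder
    gradient_workspace_id="12345614fc0_workspace",  # placeholder
    query_prompt_for_retrieval=(
        "Represent this sentence for searching relevant passages:"
    ),
)

# Documents are embedded as-is; the pre-prompt is prepended only to queries.
doc_vectors = embeddings.embed_documents(["doc1", "doc2"])
query_vector = embeddings.embed_query("pizza")
```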