From 40de0e7f59f63491b29c6baa57130b9dcb66c1bd Mon Sep 17 00:00:00 2001 From: Laurel Orr <57237365+lorr1@users.noreply.github.com> Date: Sat, 8 Apr 2023 22:55:04 -0700 Subject: [PATCH] feat: openai embedding support (#75) --- README.md | 27 +++ examples/manifest_embedding.ipynb | 139 ++++++++++++ manifest/caches/cache.py | 41 +++- manifest/caches/serializers.py | 60 ++++++ manifest/clients/ai21.py | 2 +- manifest/clients/cohere.py | 2 +- manifest/clients/openai.py | 8 +- .../clients/{openaichat.py => openai_chat.py} | 10 +- manifest/clients/openai_embedding.py | 204 ++++++++++++++++++ manifest/clients/toma.py | 2 +- manifest/clients/toma_diffuser.py | 2 +- manifest/manifest.py | 4 +- manifest/request.py | 9 +- manifest/response.py | 4 + tests/test_cache.py | 77 ++++++- tests/test_manifest.py | 111 ++++++++++ tests/test_serializer.py | 17 +- 17 files changed, 687 insertions(+), 32 deletions(-) create mode 100644 examples/manifest_embedding.ipynb rename manifest/clients/{openaichat.py => openai_chat.py} (93%) create mode 100644 manifest/clients/openai_embedding.py diff --git a/README.md b/README.md index 0cd8baf..f39e08e 100644 --- a/README.md +++ b/README.md @@ -7,6 +7,7 @@ How to make prompt programming with Foundation Models a little easier. - [Getting Started](#getting-started) - [Manifest](#manifest-components) - [Local HuggingFace Models](#local-huggingface-models) +- [Embedding Models](#embedding-models) - [Development](#development) - [Cite](#cite) @@ -47,6 +48,9 @@ manifest = Manifest( manifest.run("Why is the grass green?") ``` +## Examples +We have example notebook and python scripts located at [examples](examples). These show how to use different models, model types (i.e. text, diffusers, or embedding models), and async running. + # Manifest Components Manifest is meant to be a very light weight package to help with prompt design and iteration. Three key design decisions of Manifest are @@ -112,6 +116,12 @@ You can also run over multiple examples if supported by the client. results = manifest.run(["Where are the cats?", "Where are the dogs?"]) ``` +We support async queries as well via +```python +import asyncio +results = asyncio.run(manifest.arun_batch(["Where are the cats?", "Where are the dogs?"])) +``` + If something doesn't go right, you can also ask to get a raw manifest Response. ```python result_object = manifest.run(["Where are the cats?", "Where are the dogs?"], return_response=True) @@ -178,6 +188,23 @@ python3 -m manifest.api.app \ --percent_max_gpu_mem_reduction 0.85 ``` +# Embedding Models +Manifest also supports getting embeddings from models and available APIs. We do this all through changing the `client_name` argument. You still use `run` and `abatch_run`. + +To use OpenAI's embedding models, simply run +```python +manifest = Manifest(client_name="openaiembedding") +embedding_as_np = manifest.run("Get me an embedding for a bunny") +``` + +As explained above, you can load local HuggingFace models that give you embeddings, too. If you want to use a standard generative model, load the model as above use use `client_name="huggingfaceembedding"`. 
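+A minimal sketch of querying such a locally hosted model (the connection URL below is illustrative; use whatever host and port you started the server with):
+```python
+from manifest import Manifest
+
+# Connect to a locally hosted embedding server started via manifest.api.app.
+# The URL is an assumption -- match it to your running server.
+manifest = Manifest(
+    client_name="huggingfaceembedding",
+    client_connection="http://127.0.0.1:6001",
+)
+embedding_as_np = manifest.run("Get me an embedding for a bunny")
+```
+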
If you want to use a standard embedding model, like those from SentenceTransformers, load your local model via +```bash +python3 -m manifest.api.app \ + --model_type sentence_transformers \ + --model_name_or_path all-mpnet-base-v2 \ + --device 0 +``` + # Development Before submitting a PR, run ```bash diff --git a/examples/manifest_embedding.ipynb b/examples/manifest_embedding.ipynb new file mode 100644 index 0000000..862a3fe --- /dev/null +++ b/examples/manifest_embedding.ipynb @@ -0,0 +1,139 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Use OpenAI\n", + "\n", + "Set you `OPENAI_API_KEY` environment variable." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'model_name': 'openaiembedding', 'engine': 'text-embedding-ada-002'}\n" + ] + } + ], + "source": [ + "from manifest import Manifest\n", + "\n", + "manifest = Manifest(client_name=\"openaiembedding\")\n", + "print(manifest.client.get_model_params())" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(1536,)\n" + ] + } + ], + "source": [ + "emb = manifest.run(\"Is this an embedding?\")\n", + "print(emb.shape)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Using Locally Hosted Huggingface LM\n", + "\n", + "Run\n", + "```\n", + "python3 manifest/api/app.py --model_type huggingface --model_name_or_path EleutherAI/gpt-neo-125M --device 0\n", + "```\n", + "in a separate `screen` or `tmux`." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'model_name': 'EleutherAI/gpt-neo-125M', 'model_path': 'EleutherAI/gpt-neo-125M'}\n" + ] + } + ], + "source": [ + "from manifest import Manifest\n", + "\n", + "# Local hosted GPT Neo 125M\n", + "manifest = Manifest(\n", + " client_name=\"huggingfaceembedding\",\n", + " client_connection=\"http://127.0.0.1:6001\",\n", + " cache_name=\"sqlite\",\n", + " cache_connection=\"my_sqlite_manifest.sqlite\"\n", + ")\n", + "print(manifest.client.get_model_params())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "emb = manifest.run(\"Is this an embedding?\")\n", + "emb2 = manifest.run(\"Is this an embedding?\", aggregation=\"mean\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "manifest", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.4" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "fddffe4ac3b9f00470127629076101c1b5f38ecb1e7358b567d19305425e9491" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/manifest/caches/cache.py b/manifest/caches/cache.py index 2a16557..74df2fb 100644 --- a/manifest/caches/cache.py +++ b/manifest/caches/cache.py @@ -2,12 +2,14 @@ from abc import ABC, abstractmethod from typing import Any, Dict, Union -from manifest.caches.serializers import ArraySerializer, Serializer +from manifest.caches.serializers import ArraySerializer, NumpyByteSerializer, Serializer from manifest.response import RESPONSE_CONSTRUCTORS, Response -CACHE_CONSTRUCTOR = { - "diffuser": ArraySerializer, - "tomadiffuser": ArraySerializer, +# Non-text return type caches +ARRAY_CACHE_TYPES = { + "diffuser", + "tomadiffuser", + "openaiembedding", } @@ -21,17 +23,21 @@ class Cache(ABC): cache_args: Dict[str, Any] = {}, ): """ - Initialize client. + Initialize cache. Args: connection_str: connection string. client_name: name of client. cache_args: arguments for cache. - cache_args are passed to client as default parameters. + cache_args are any arguments needed to initialize the cache. - For clients like OpenAI that do not require a connection, - the connection_str can be None. + Further, cache_args can contain `array_serializer` as a string + for embedding or image return types (e.g. diffusers) with values + as `local_file` or `byte_string`. `local_file` will save the + array in a local file and cache a pointer to the file. + `byte_string` will convert the array to a byte string and cache + the entire byte string. `byte_string` is default. Args: connection_str: connection string for client. 
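For reference, a minimal sketch of how a caller selects the serializer through `Manifest` (the cache path is illustrative), mirroring the tests added in this patch:

```python
from manifest import Manifest

# Array-returning clients ("diffuser", "tomadiffuser", "openaiembedding") accept
# an `array_serializer` cache argument; "byte_string" is the default.
manifest = Manifest(
    client_name="openaiembedding",
    cache_name="sqlite",
    cache_connection="my_manifest_cache.sqlite",  # illustrative path
    array_serializer="local_file",  # or "byte_string"
)
```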
@@ -39,7 +45,22 @@ class Cache(ABC): """ self.client_name = client_name self.connect(connection_str, cache_args) - self.serializer = CACHE_CONSTRUCTOR.get(client_name, Serializer)() + if self.client_name in ARRAY_CACHE_TYPES: + array_serializer = cache_args.pop("array_serializer", "byte_string") + if array_serializer not in ["local_file", "byte_string"]: + raise ValueError( + "array_serializer must be local_file or byte_string," + f" not {array_serializer}" + ) + self.serializer = ( + ArraySerializer() + if array_serializer == "local_file" + else NumpyByteSerializer() + ) + else: + # If user has array_serializer type, it will throw an error as + # it is not recognized for non-array return types. + self.serializer = Serializer() @abstractmethod def close(self) -> None: @@ -107,7 +128,7 @@ class Cache(ABC): response, cached, request, - **RESPONSE_CONSTRUCTORS.get(self.client_name, {}) + **RESPONSE_CONSTRUCTORS.get(self.client_name, {}), ) return None diff --git a/manifest/caches/serializers.py b/manifest/caches/serializers.py index cf9fc93..d6ad506 100644 --- a/manifest/caches/serializers.py +++ b/manifest/caches/serializers.py @@ -1,10 +1,12 @@ """Serializer.""" +import io import json import os from pathlib import Path from typing import Dict +import numpy as np import xxhash from manifest.caches.array_cache import ArrayCache @@ -62,6 +64,64 @@ class Serializer: return json.loads(key) +class NumpyByteSerializer(Serializer): + """Serializer by casting array to byte string.""" + + def response_to_key(self, response: Dict) -> str: + """ + Normalize a response into a key. + + Args: + response: response to normalize. + + Returns: + normalized key. + """ + # Assume response is a dict with keys "choices" -> List dicts + # with keys "array". + choices = response["choices"] + # We don't want to modify the response in place + # but we want to avoid calling deepcopy on an array + del response["choices"] + response_copy = response.copy() + response["choices"] = choices + response_copy["choices"] = [] + for choice in choices: + if "array" not in choice: + raise ValueError( + f"Choice with keys {choice.keys()} does not have array key." + ) + arr = choice["array"] + # Avoid copying an array + del choice["array"] + new_choice = choice.copy() + choice["array"] = arr + with io.BytesIO() as f: + np.savez_compressed(f, data=arr) + hash_str = f.getvalue().hex() + new_choice["array"] = hash_str + response_copy["choices"].append(new_choice) + return json.dumps(response_copy, sort_keys=True) + + def key_to_response(self, key: str) -> Dict: + """ + Convert the normalized version to the response. + + Args: + key: normalized key to convert. + + Returns: + unnormalized response dict. + """ + response = json.loads(key) + for choice in response["choices"]: + hash_str = choice["array"] + byte_str = bytes.fromhex(hash_str) + with io.BytesIO(byte_str) as f: + choice["array"] = np.load(f)["data"] + return response + + class ArraySerializer(Serializer): """Serializer for array.""" diff --git a/manifest/clients/ai21.py b/manifest/clients/ai21.py index bf6f6cd..6bd356c 100644 --- a/manifest/clients/ai21.py +++ b/manifest/clients/ai21.py @@ -92,7 +92,7 @@ class AI21Client(Client): Returns: model params. 
""" - return {"model_name": "ai21", "engine": getattr(self, "engine")} + return {"model_name": self.NAME, "engine": getattr(self, "engine")} def format_response(self, response: Dict, request: Dict) -> Dict[str, Any]: """ diff --git a/manifest/clients/cohere.py b/manifest/clients/cohere.py index 92e602e..62b3d89 100644 --- a/manifest/clients/cohere.py +++ b/manifest/clients/cohere.py @@ -91,7 +91,7 @@ class CohereClient(Client): Returns: model params. """ - return {"model_name": "cohere", "engine": getattr(self, "engine")} + return {"model_name": self.NAME, "engine": getattr(self, "engine")} def format_response(self, response: Dict, request: Dict) -> Dict[str, Any]: """ diff --git a/manifest/clients/openai.py b/manifest/clients/openai.py index 1c92907..79720c1 100644 --- a/manifest/clients/openai.py +++ b/manifest/clients/openai.py @@ -1,12 +1,12 @@ """OpenAI client.""" import logging import os -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Type import tiktoken from manifest.clients.client import Client -from manifest.request import LMRequest +from manifest.request import LMRequest, Request logger = logging.getLogger(__name__) @@ -41,7 +41,7 @@ class OpenAIClient(Client): "presence_penalty": ("presence_penalty", 0.0), "frequency_penalty": ("frequency_penalty", 0.0), } - REQUEST_CLS = LMRequest + REQUEST_CLS: Type[Request] = LMRequest NAME = "openai" def connect( @@ -103,7 +103,7 @@ class OpenAIClient(Client): Returns: model params. """ - return {"model_name": "openai", "engine": getattr(self, "engine")} + return {"model_name": self.NAME, "engine": getattr(self, "engine")} def split_usage(self, request: Dict, choices: List[str]) -> List[Dict[str, int]]: """Split usage into list of usages for each prompt.""" diff --git a/manifest/clients/openaichat.py b/manifest/clients/openai_chat.py similarity index 93% rename from manifest/clients/openaichat.py rename to manifest/clients/openai_chat.py index b6ebe36..b37a5d7 100644 --- a/manifest/clients/openaichat.py +++ b/manifest/clients/openai_chat.py @@ -77,7 +77,7 @@ class OpenAIChatClient(OpenAIClient): Returns: model params. """ - return {"model_name": "openaichat", "engine": getattr(self, "engine")} + return {"model_name": self.NAME, "engine": getattr(self, "engine")} def _format_request_for_chat(self, request_params: Dict[str, Any]) -> Dict: """Format request params for chat. @@ -99,8 +99,8 @@ class OpenAIChatClient(OpenAIClient): request_params["messages"] = messages return request_params - def _format_request_for_text(self, response_dict: Dict[str, Any]) -> Dict: - """Format response for text. + def _format_request_from_chat(self, response_dict: Dict[str, Any]) -> Dict: + """Format response for standard response from chat. Args: response_dict: response. 
@@ -131,7 +131,7 @@ class OpenAIChatClient(OpenAIClient): request_params = self._format_request_for_chat(request_params) response_dict = super()._run_completion(request_params, retry_timeout) # Reformat for text model - response_dict = self._format_request_for_text(response_dict) + response_dict = self._format_request_from_chat(response_dict) return response_dict async def _arun_completion( @@ -153,5 +153,5 @@ class OpenAIChatClient(OpenAIClient): request_params, retry_timeout, batch_size ) # Reformat for text model - response_dict = self._format_request_for_text(response_dict) + response_dict = self._format_request_from_chat(response_dict) return response_dict diff --git a/manifest/clients/openai_embedding.py b/manifest/clients/openai_embedding.py new file mode 100644 index 0000000..ff41b5a --- /dev/null +++ b/manifest/clients/openai_embedding.py @@ -0,0 +1,204 @@ +"""OpenAI client.""" +import copy +import logging +import os +from typing import Any, Dict, List, Optional + +import numpy as np +import tiktoken + +from manifest.clients.openai import OpenAIClient +from manifest.request import EmbeddingRequest + +logger = logging.getLogger(__name__) + +OPENAI_EMBEDDING_ENGINES = { + "text-embedding-ada-002", +} + + +class OpenAIEmbeddingClient(OpenAIClient): + """OpenAI client.""" + + # User param -> (client param, default value) + PARAMS = { + "engine": ("model", "text-embedding-ada-002"), + } + REQUEST_CLS = EmbeddingRequest + NAME = "openaiembedding" + + def connect( + self, + connection_str: Optional[str] = None, + client_args: Dict[str, Any] = {}, + ) -> None: + """ + Connect to the OpenAI server. + + connection_str is passed as default OPENAI_API_KEY if variable not set. + + Args: + connection_str: connection string. + client_args: client arguments. + """ + self.api_key = os.environ.get("OPENAI_API_KEY", connection_str) + if self.api_key is None: + raise ValueError( + "OpenAI API key not set. Set OPENAI_API_KEY environment " + "variable or pass through `client_connection`." + ) + self.host = "https://api.openai.com/v1" + for key in self.PARAMS: + setattr(self, key, client_args.pop(key, self.PARAMS[key][1])) + if getattr(self, "engine") not in OPENAI_EMBEDDING_ENGINES: + raise ValueError( + f"Invalid engine {getattr(self, 'engine')}. " + f"Must be {OPENAI_EMBEDDING_ENGINES}." + ) + + def get_generation_url(self) -> str: + """Get generation URL.""" + return self.host + "/embeddings" + + def supports_batch_inference(self) -> bool: + """Return whether the client supports batch inference.""" + return True + + def get_model_params(self) -> Dict: + """ + Get model params. + + By getting model params from the server, we can add to request + and make sure cache keys are unique to model. + + Returns: + model params. + """ + return {"model_name": self.NAME, "engine": getattr(self, "engine")} + + def format_response(self, response: Dict, request: Dict) -> Dict[str, Any]: + """ + Format response to dict. 
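+
+        OpenAI's embeddings endpoint returns a ``data`` list (one entry per
+        input) plus a ``usage`` block; for batch requests the usage is split
+        per prompt so it lines up with each returned embedding.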
+ + Args: + response: response + request: request + + Return: + response as dict + """ + if "data" not in response: + raise ValueError(f"Invalid response: {response}") + if "usage" in response: + # Handle splitting the usages for batch requests + if len(response["data"]) == 1: + if isinstance(response["usage"], list): + response["usage"] = response["usage"][0] + response["usage"] = [response["usage"]] + else: + # Try to split usage + split_usage = self.split_usage(request, response["data"]) + if split_usage: + response["usage"] = split_usage + return response + + def _format_request_for_embedding(self, request_params: Dict[str, Any]) -> Dict: + """Format request params for embedding. + + Args: + request_params: request params. + + Returns: + formatted request params. + """ + # Format for embedding model + request_params = copy.deepcopy(request_params) + prompt = request_params.pop("prompt") + if isinstance(prompt, str): + prompt_list = [prompt] + else: + prompt_list = prompt + request_params["input"] = prompt_list + return request_params + + def _format_request_from_embedding(self, response_dict: Dict[str, Any]) -> Dict: + """Format response from embedding for standard response. + + Args: + response_dict: response. + + Return: + formatted response. + """ + new_choices = [] + response_dict = copy.deepcopy(response_dict) + for res in response_dict.pop("data"): + new_choices.append({"array": np.array(res["embedding"])}) + response_dict["choices"] = new_choices + return response_dict + + def _run_completion( + self, request_params: Dict[str, Any], retry_timeout: int + ) -> Dict: + """Execute completion request. + + Args: + request_params: request params. + retry_timeout: retry timeout. + + Returns: + response as dict. + """ + # Format for embedding model + request_params = self._format_request_for_embedding(request_params) + response_dict = super()._run_completion(request_params, retry_timeout) + # Reformat for text model + response_dict = self._format_request_from_embedding(response_dict) + return response_dict + + async def _arun_completion( + self, request_params: Dict[str, Any], retry_timeout: int, batch_size: int + ) -> Dict: + """Async execute completion request. + + Args: + request_params: request params. + retry_timeout: retry timeout. + batch_size: batch size for requests. + + Returns: + response as dict. 
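+
+        Note: as in the synchronous path, the prompt is sent as OpenAI's
+        ``input`` field and each returned ``data[i]["embedding"]`` is
+        converted to a numpy array under ``choices[i]["array"]``.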
+ """ + # Format for embedding model + request_params = self._format_request_for_embedding(request_params) + response_dict = await super()._arun_completion( + request_params, retry_timeout, batch_size + ) + # Reformat for text model + response_dict = self._format_request_from_embedding(response_dict) + return response_dict + + def split_usage(self, request: Dict, choices: List[str]) -> List[Dict[str, int]]: + """Split usage into list of usages for each prompt.""" + try: + encoding = tiktoken.encoding_for_model(getattr(self, "engine")) + except Exception: + return [] + prompt = request["input"] + if isinstance(prompt, str): + prompts = [prompt] + else: + prompts = prompt + assert len(prompts) == len(choices) + usages = [] + for pmt in prompts: + pmt_tokens = len(encoding.encode(pmt)) + # No completion tokens for embedding models + chc_tokens = 0 + usage = { + "prompt_tokens": pmt_tokens, + "completion_tokens": chc_tokens, + "total_tokens": pmt_tokens + chc_tokens, + } + usages.append(usage) + return usages diff --git a/manifest/clients/toma.py b/manifest/clients/toma.py index 4147b45..259a310 100644 --- a/manifest/clients/toma.py +++ b/manifest/clients/toma.py @@ -121,7 +121,7 @@ class TOMAClient(Client): Returns: model params. """ - return {"model_name": "toma", "engine": getattr(self, "engine")} + return {"model_name": self.NAME, "engine": getattr(self, "engine")} def get_model_heartbeats(self) -> Dict[str, Dict]: """ diff --git a/manifest/clients/toma_diffuser.py b/manifest/clients/toma_diffuser.py index 1f91561..c38c46a 100644 --- a/manifest/clients/toma_diffuser.py +++ b/manifest/clients/toma_diffuser.py @@ -44,7 +44,7 @@ class TOMADiffuserClient(TOMAClient): Returns: model params. """ - return {"model_name": "tomadiffuser", "engine": getattr(self, "engine")} + return {"model_name": self.NAME, "engine": getattr(self, "engine")} def format_response(self, response: Dict, request: Dict) -> Dict[str, Any]: """ diff --git a/manifest/manifest.py b/manifest/manifest.py index 2c04eb0..0717c61 100644 --- a/manifest/manifest.py +++ b/manifest/manifest.py @@ -14,7 +14,8 @@ from manifest.clients.cohere import CohereClient from manifest.clients.dummy import DummyClient from manifest.clients.huggingface import HuggingFaceClient from manifest.clients.openai import OpenAIClient -from manifest.clients.openaichat import OpenAIChatClient +from manifest.clients.openai_chat import OpenAIChatClient +from manifest.clients.openai_embedding import OpenAIEmbeddingClient from manifest.clients.toma import TOMAClient from manifest.request import Request from manifest.response import Response @@ -25,6 +26,7 @@ logger = logging.getLogger(__name__) CLIENT_CONSTRUCTORS = { OpenAIClient.NAME: OpenAIClient, OpenAIChatClient.NAME: OpenAIChatClient, + OpenAIEmbeddingClient.NAME: OpenAIEmbeddingClient, CohereClient.NAME: CohereClient, AI21Client.NAME: AI21Client, HuggingFaceClient.NAME: HuggingFaceClient, diff --git a/manifest/request.py b/manifest/request.py index 27dc57b..a6c31de 100644 --- a/manifest/request.py +++ b/manifest/request.py @@ -1,5 +1,5 @@ """Request object.""" -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Literal, Optional, Tuple, Union from pydantic import BaseModel @@ -96,6 +96,13 @@ class LMRequest(Request): frequency_penalty: float = 0 +class EmbeddingRequest(Request): + """Embedding Request object.""" + + # Aggregate method (if applicable) + aggregation_method: Optional[Literal["last_token", "mean"]] = None + + class DiffusionRequest(Request): 
"""Diffusion Model Request object.""" diff --git a/manifest/response.py b/manifest/response.py index 90dfcc4..2b9a1ef 100644 --- a/manifest/response.py +++ b/manifest/response.py @@ -13,6 +13,10 @@ RESPONSE_CONSTRUCTORS = { "logits_key": "token_logprobs", "item_key": "array", }, + "openaiembedding": { + "logits_key": "token_logprobs", + "item_key": "array", + }, } diff --git a/tests/test_cache.py b/tests/test_cache.py index df74929..3821a90 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -1,5 +1,5 @@ """Cache test.""" -from typing import cast +from typing import Dict, cast import numpy as np import pytest @@ -13,12 +13,15 @@ from manifest.caches.redis import RedisCache from manifest.caches.sqlite import SQLiteCache -def _get_postgres_cache(**kwargs) -> Cache: # type: ignore +def _get_postgres_cache( + client_name: str = "", cache_args: Dict = {} +) -> Cache: # type: ignore """Get postgres cache.""" + cache_args.update({"cache_user": "", "cache_password": "", "cache_db": ""}) return PostgresCache( "postgres", - cache_args={"cache_user": "", "cache_password": "", "cache_db": ""}, - **kwargs, + client_name=client_name, + cache_args=cache_args, ) @@ -96,11 +99,11 @@ def test_get( assert response.is_cached() assert response.get_request() == test_request + # Test array arr = np.random.rand(4, 4) test_request = {"test": "hello", "testA": "world of images"} compute_arr_response = {"choices": [{"array": arr}]} - # Test array if cache_type == "sqlite": cache = SQLiteCache(sqlite_cache, client_name="diffuser") elif cache_type == "redis": @@ -117,6 +120,37 @@ def test_get( assert response.is_cached() assert response.get_request() == test_request + # Test array byte string + arr = np.random.rand(4, 4) + test_request = {"test": "hello", "testA": "world of images 2"} + compute_arr_response = {"choices": [{"array": arr}]} + + if cache_type == "sqlite": + cache = SQLiteCache( + sqlite_cache, + client_name="diffuser", + cache_args={"array_serializer": "byte_string"}, + ) + elif cache_type == "redis": + cache = RedisCache( + redis_cache, + client_name="diffuser", + cache_args={"array_serializer": "byte_string"}, + ) + elif cache_type == "postgres": + cache = _get_postgres_cache( + client_name="diffuser", cache_args={"array_serializer": "byte_string"} + ) + + response = cache.get(test_request) + assert response is None + + cache.set(test_request, compute_arr_response) + response = cache.get(test_request) + assert np.allclose(response.get_response(), arr) + assert response.is_cached() + assert response.get_request() == test_request + @pytest.mark.usefixtures("sqlite_cache") @pytest.mark.usefixtures("redis_cache") @@ -168,6 +202,39 @@ def test_get_batch_prompt( assert response.is_cached() assert response.get_request() == test_request + # Test arrays byte serializer + arr = np.random.rand(4, 4) + arr2 = np.random.rand(4, 4) + test_request = {"test": ["hello", "goodbye"], "testA": "world of images 2"} + compute_arr_response = {"choices": [{"array": arr}, {"array": arr2}]} + + if cache_type == "sqlite": + cache = SQLiteCache( + sqlite_cache, + client_name="diffuser", + cache_args={"array_serializer": "byte_string"}, + ) + elif cache_type == "redis": + cache = RedisCache( + redis_cache, + client_name="diffuser", + cache_args={"array_serializer": "byte_string"}, + ) + elif cache_type == "postgres": + cache = _get_postgres_cache( + client_name="diffuser", cache_args={"array_serializer": "byte_string"} + ) + + response = cache.get(test_request) + assert response is None + + cache.set(test_request, 
compute_arr_response) + response = cache.get(test_request) + assert np.allclose(response.get_response()[0], arr) + assert np.allclose(response.get_response()[1], arr2) + assert response.is_cached() + assert response.get_request() == test_request + def test_noop_cache() -> None: """Test cache that is a no-op cache.""" diff --git a/tests/test_manifest.py b/tests/test_manifest.py index 49c197e..b6b99ba 100644 --- a/tests/test_manifest.py +++ b/tests/test_manifest.py @@ -4,6 +4,7 @@ import os from typing import cast from unittest.mock import MagicMock, Mock, patch +import numpy as np import pytest import requests from requests import HTTPError @@ -567,6 +568,7 @@ def test_openai(sqlite_cache: str) -> None: response = cast(Response, client.run("Why are there apples?", return_response=True)) assert isinstance(response.get_response(), str) and len(response.get_response()) > 0 + assert response.get_response() == res assert response.is_cached() is True assert "usage" in response.get_json_response() assert response.get_json_response()["usage"][0]["total_tokens"] == 15 @@ -643,6 +645,7 @@ def test_openaichat(sqlite_cache: str) -> None: response = cast(Response, client.run("Why are there apples?", return_response=True)) assert isinstance(response.get_response(), str) and len(response.get_response()) > 0 + assert response.get_response() == res assert response.is_cached() is True assert "usage" in response.get_json_response() assert response.get_json_response()["usage"][0]["total_tokens"] == 23 @@ -685,6 +688,114 @@ def test_openaichat(sqlite_cache: str) -> None: assert response.is_cached() is True +@pytest.mark.skipif(not OPENAI_ALIVE, reason="No openai key set") +@pytest.mark.usefixtures("sqlite_cache") +def test_openaiembedding(sqlite_cache: str) -> None: + """Test openaichat client.""" + client = Manifest( + client_name="openaiembedding", + cache_name="sqlite", + cache_connection=sqlite_cache, + array_serializer="local_file", + ) + + res = client.run("Why are there carrots?") + assert isinstance(res, np.ndarray) + + response = cast( + Response, client.run("Why are there carrots?", return_response=True) + ) + assert isinstance(response.get_response(), np.ndarray) + assert np.allclose(response.get_response(), res) + + client = Manifest( + client_name="openaiembedding", + cache_name="sqlite", + cache_connection=sqlite_cache, + ) + + res = client.run("Why are there apples?") + assert isinstance(res, np.ndarray) + + response = cast(Response, client.run("Why are there apples?", return_response=True)) + assert isinstance(response.get_response(), np.ndarray) + assert np.allclose(response.get_response(), res) + assert response.is_cached() is True + assert "usage" in response.get_json_response() + assert response.get_json_response()["usage"][0]["total_tokens"] == 5 + + response = cast(Response, client.run("Why are there apples?", return_response=True)) + assert response.is_cached() is True + + res_list = client.run(["Why are there apples?", "Why are there bananas?"]) + assert ( + isinstance(res_list, list) + and len(res_list) == 2 + and isinstance(res_list[0], np.ndarray) + ) + + response = cast( + Response, + client.run( + ["Why are there apples?", "Why are there mangos?"], return_response=True + ), + ) + assert ( + isinstance(response.get_response(), list) and len(response.get_response()) == 2 + ) + assert ( + "usage" in response.get_json_response() + and len(response.get_json_response()["usage"]) == 2 + ) + assert response.get_json_response()["usage"][0]["total_tokens"] == 5 + assert 
response.get_json_response()["usage"][1]["total_tokens"] == 6 + + response = cast( + Response, client.run("Why are there bananas?", return_response=True) + ) + assert response.is_cached() is True + + response = cast( + Response, client.run("Why are there oranges?", return_response=True) + ) + assert response.is_cached() is False + + res_list = asyncio.run( + client.arun_batch(["Why are there pears?", "Why are there oranges?"]) + ) + assert ( + isinstance(res_list, list) + and len(res_list) == 2 + and isinstance(res_list[0], np.ndarray) + ) + + response = cast( + Response, + asyncio.run( + client.arun_batch( + ["Why are there pinenuts?", "Why are there cocoa?"], + return_response=True, + ) + ), + ) + assert ( + isinstance(response.get_response(), list) + and len(res_list) == 2 + and isinstance(res_list[0], np.ndarray) + ) + assert ( + "usage" in response.get_json_response() + and len(response.get_json_response()["usage"]) == 2 + ) + assert response.get_json_response()["usage"][0]["total_tokens"] == 7 + assert response.get_json_response()["usage"][1]["total_tokens"] == 5 + + response = cast( + Response, client.run("Why are there oranges?", return_response=True) + ) + assert response.is_cached() is True + + def test_retry_handling() -> None: """Test retry handling.""" # We'll mock the response so we won't need a real connection diff --git a/tests/test_serializer.py b/tests/test_serializer.py index 8f4e4a0..bbc269f 100644 --- a/tests/test_serializer.py +++ b/tests/test_serializer.py @@ -3,10 +3,10 @@ import json import numpy as np -from manifest.caches.serializers import ArraySerializer +from manifest.caches.serializers import ArraySerializer, NumpyByteSerializer -def test_response_to_key() -> None: +def test_response_to_key_array() -> None: """Test array serializer initialization.""" serializer = ArraySerializer() arr = np.random.rand(4, 4) @@ -17,3 +17,16 @@ def test_response_to_key() -> None: res2 = serializer.key_to_response(key) assert np.allclose(arr, res2["choices"][0]["array"]) + + +def test_response_to_key_numpybytes() -> None: + """Test array serializer initialization.""" + serializer = NumpyByteSerializer() + arr = np.random.rand(4, 4) + res = {"choices": [{"array": arr}]} + key = serializer.response_to_key(res) + key_dct = json.loads(key) + assert isinstance(key_dct["choices"][0]["array"], str) + + res2 = serializer.key_to_response(key) + assert np.allclose(arr, res2["choices"][0]["array"])