From 6324e0fe439723cfc5a9238551562d9f4287e39c Mon Sep 17 00:00:00 2001 From: Laurel Orr <57237365+lorr1@users.noreply.github.com> Date: Sun, 21 May 2023 15:50:03 -0700 Subject: [PATCH] feat: streaming support completions (#99) --- CHANGELOG.rst | 1 + README.md | 34 +++-- examples/manifest_streaming.ipynb | 105 ++++++++++++++++ manifest/clients/ai21.py | 7 ++ manifest/clients/client.py | 124 +++++++++++++++++- manifest/clients/cohere.py | 7 ++ manifest/clients/diffuser.py | 7 ++ manifest/clients/dummy.py | 7 ++ manifest/clients/google.py | 7 ++ manifest/clients/huggingface.py | 7 ++ manifest/clients/huggingface_embedding.py | 7 ++ manifest/clients/openai.py | 7 ++ manifest/clients/openai_chat.py | 7 +- manifest/clients/openai_embedding.py | 7 ++ manifest/clients/toma.py | 7 ++ manifest/manifest.py | 137 +++++++++++++++++--- manifest/response.py | 145 +++++++++++++++++++--- tests/conftest.py | 10 +- tests/test_manifest.py | 83 ++++++++++++- tests/test_response.py | 94 +++++++++++++- 20 files changed, 754 insertions(+), 56 deletions(-) create mode 100644 examples/manifest_streaming.ipynb diff --git a/CHANGELOG.rst b/CHANGELOG.rst index b8663f1..1dd8021 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -4,6 +4,7 @@ Added ^^^^^ * Azure model support (completion and chat) * Google Vertex API model support (completion and chat) +* Streaming responses for LM Completions (set stream=True) Fixed ^^^^^ diff --git a/README.md b/README.md index 6fe934c..34e7715 100644 --- a/README.md +++ b/README.md @@ -6,9 +6,10 @@ How to make prompt programming with Foundation Models a little easier. - [Install](#install) - [Getting Started](#getting-started) - [Manifest](#manifest-components) -- [Local HuggingFace Models](#local-huggingface-models) -- [Chat Models](#chat-models) -- [Embedding Models](#embedding-models) +- [Other Models Types](#other-models) + - [Local HuggingFace Models](#local-huggingface-models) + - [Chat Models](#chat-models) + - [Embedding Models](#embedding-models) - [Road Map](#road-map) - [Development](#development) - [Cite](#cite) @@ -43,7 +44,7 @@ Running is simple to get started. If using OpenAI, set `export OPENAI_API_KEY= 80:\n", + " sys.stdout.write(\"\\n\")\n", + " cur_line_length = 0" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "manifest", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.4" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "fddffe4ac3b9f00470127629076101c1b5f38ecb1e7358b567d19305425e9491" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/manifest/clients/ai21.py b/manifest/clients/ai21.py index 6cb4004..03a8669 100644 --- a/manifest/clients/ai21.py +++ b/manifest/clients/ai21.py @@ -82,6 +82,13 @@ class AI21Client(Client): """Return whether the client supports batch inference.""" return False + def supports_streaming_inference(self) -> bool: + """Return whether the client supports streaming inference. + + Override in child client class. + """ + return False + def get_model_params(self) -> Dict: """ Get model params. 
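For orientation before the client-level changes below, here is a minimal usage sketch of the `stream=True` flag this patch introduces (the streaming notebook above is only partially reproduced); the client choice, prompt, and print loop are illustrative and not taken from the notebook:

```python
# Minimal sketch of the new streaming API, assuming OPENAI_API_KEY is set
# as described in the README. The prompt and print loop are illustrative.
from manifest import Manifest

manifest = Manifest(client_name="openai")

# With stream=True, run() returns an iterator that yields the completion
# incrementally: token by token for live requests, word by word when a
# cached response without token-level data is replayed.
for chunk in manifest.run("Tell me a short story about a fox.", stream=True):
    print(chunk, end="", flush=True)
```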
diff --git a/manifest/clients/client.py b/manifest/clients/client.py index a034e36..eb342ef 100644 --- a/manifest/clients/client.py +++ b/manifest/clients/client.py @@ -1,10 +1,11 @@ """Client class.""" import asyncio import copy +import json import logging import math from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional, Tuple, Union, cast +from typing import Any, Dict, Generator, List, Optional, Tuple, Union, cast import aiohttp import requests @@ -107,6 +108,7 @@ class Client(ABC): """ Connect to client. + Override in child client class. Args: connection_str: connection string. """ @@ -114,12 +116,18 @@ class Client(ABC): @abstractmethod def close(self) -> None: - """Close the client.""" + """Close the client. + + Override in child client class. + """ raise NotImplementedError() @abstractmethod def get_generation_url(self) -> str: - """Get generation URL.""" + """Get generation URL. + + Override in child client class. + """ raise NotImplementedError() @abstractmethod @@ -127,6 +135,7 @@ class Client(ABC): """ Get generation header. + Override in child client class. Returns: header. """ @@ -134,7 +143,18 @@ class Client(ABC): @abstractmethod def supports_batch_inference(self) -> bool: - """Return whether the client supports batch inference.""" + """Return whether the client supports batch inference. + + Override in child client class. + """ + raise NotImplementedError() + + @abstractmethod + def supports_streaming_inference(self) -> bool: + """Return whether the client supports streaming inference. + + Override in child client class. + """ raise NotImplementedError() @abstractmethod @@ -145,6 +165,7 @@ class Client(ABC): By getting model params from the server, we can add to request and make sure cache keys are unique to model. + Override in child client class. Returns: model params. """ @@ -153,6 +174,8 @@ class Client(ABC): def get_tokenizer(self, model: str) -> Tuple[Any, int]: """Get tokenizer for model. + Override in child client class. Return None, -1 if not supported + or no prompt truncation required. Returns: tokenizer: tokenizer with encoder and decode max_length: max length of model @@ -177,6 +200,8 @@ class Client(ABC): """ Preprocess request params. + Override in child client class to reformat requests to model. + Args: request: request params. @@ -191,6 +216,8 @@ class Client(ABC): """ Postprocess and validate response as dict. + Override in child client class to reform model responses. + Args: response: response request: request @@ -314,6 +341,7 @@ class Client(ABC): final_usages = None if usages: final_usages = Usages(usages=[Usage(**usage) for usage in usages]) + # TODO: Add usage based on tokenizer return Response( self._get_model_choices(final_response_dict), cached=False, @@ -415,6 +443,55 @@ class Client(ABC): res_json = await res.json(content_type=None) return self.postprocess_response(res_json, request_params) + @retry( + reraise=True, + retry=retry_if_ratelimit, + wait=wait_random_exponential(min=1, max=ATTEMPTS_TIMEOUT), + stop=stop_after_attempt(ATTEMPTS_BEFORE_STOP), + ) + def _run_streaming_completion( + self, request_params: Dict[str, Any], retry_timeout: int + ) -> Generator[Dict, None, None]: + """Execute completion request streaming. + + Args: + request_params: request params. + retry_timeout: retry timeout. + + Returns: + response as dict. 
+ """ + request_params = self.preprocess_request_params(request_params) + request_params["stream"] = True + post_str = self.get_generation_url() + res_iter = requests.post( + post_str, + headers=self.get_generation_header(), + json=request_params, + timeout=retry_timeout, + stream=True, + ) + for res_token in res_iter.iter_lines(): + if res_token: + decoded_res_token = res_token.decode("utf-8") + decoded_res_token = decoded_res_token.replace("data: ", "") + if decoded_res_token == "[DONE]": + break + try: + decoded_res_token_dct = json.loads(decoded_res_token) + postprocess_res_token_dct = self.postprocess_response( + decoded_res_token_dct, request_params + ) + # If nothing is returned, skip + if ( + not postprocess_res_token_dct + or not postprocess_res_token_dct["choices"] + ): + continue + yield postprocess_res_token_dct + except Exception as e: + raise e + def run_request(self, request: Request) -> Response: """ Run request. @@ -563,6 +640,45 @@ class Client(ABC): **RESPONSE_CONSTRUCTORS[LMChatRequest], # type: ignore ) + def run_streaming_request( + self, request: Request + ) -> Generator[Response, None, None]: + """ + Run streaming request. + + Args: + request: request. + + Returns: + response. + """ + if not isinstance(request.prompt, str): + raise ValueError("Streaming requests must have a single prompt.") + if not self.supports_streaming_inference(): + raise ValueError( + f"{self.__class__.__name__} does not support streaming inference." + ) + request_params = self._get_request_params(request) + + # Take the default keys we need and drop the rest as they + # are not part of the model request. + retry_timeout = request_params.pop("client_timeout") + for key in DEFAULT_REQUEST_KEYS: + request_params.pop(key, None) + + # Make sure requests are in the request length + # If no tokenizer is set or not LM request, this + # will do nothing + if isinstance(request, LMRequest): + self._verify_request_lengths( + request_params, model=request.engine, max_tokens=request.max_tokens + ) + + for token_response in self._run_streaming_completion( + request_params, retry_timeout + ): + yield self._stitch_responses(request, [token_response]) + def run_score_prompt_request( self, request: LMScoreRequest, diff --git a/manifest/clients/cohere.py b/manifest/clients/cohere.py index aa536d8..b2192ec 100644 --- a/manifest/clients/cohere.py +++ b/manifest/clients/cohere.py @@ -81,6 +81,13 @@ class CohereClient(Client): """Return whether the client supports batch inference.""" return False + def supports_streaming_inference(self) -> bool: + """Return whether the client supports streaming inference. + + Override in child client class. + """ + return False + def get_model_params(self) -> Dict: """ Get model params. diff --git a/manifest/clients/diffuser.py b/manifest/clients/diffuser.py index d3ca4e3..6419a38 100644 --- a/manifest/clients/diffuser.py +++ b/manifest/clients/diffuser.py @@ -72,6 +72,13 @@ class DiffuserClient(Client): """Return whether the client supports batch inference.""" return True + def supports_streaming_inference(self) -> bool: + """Return whether the client supports streaming inference. + + Override in child client class. + """ + return False + def get_model_params(self) -> Dict: """ Get model params. 
diff --git a/manifest/clients/dummy.py b/manifest/clients/dummy.py index 0a760f8..3a15577 100644 --- a/manifest/clients/dummy.py +++ b/manifest/clients/dummy.py @@ -48,6 +48,13 @@ class DummyClient(Client): """Return whether the client supports batch inference.""" return True + def supports_streaming_inference(self) -> bool: + """Return whether the client supports streaming inference. + + Override in child client class. + """ + return False + def get_generation_header(self) -> Dict[str, str]: """ Get generation header. diff --git a/manifest/clients/google.py b/manifest/clients/google.py index f0c1fed..83c2658 100644 --- a/manifest/clients/google.py +++ b/manifest/clients/google.py @@ -117,6 +117,13 @@ class GoogleClient(Client): """Return whether the client supports batch inference.""" return True + def supports_streaming_inference(self) -> bool: + """Return whether the client supports streaming inference. + + Override in child client class. + """ + return False + def get_model_params(self) -> Dict: """ Get model params. diff --git a/manifest/clients/huggingface.py b/manifest/clients/huggingface.py index 4ad1451..9a61027 100644 --- a/manifest/clients/huggingface.py +++ b/manifest/clients/huggingface.py @@ -66,6 +66,13 @@ class HuggingFaceClient(Client): """Return whether the client supports batch inference.""" return True + def supports_streaming_inference(self) -> bool: + """Return whether the client supports streaming inference. + + Override in child client class. + """ + return False + def get_model_params(self) -> Dict: """ Get model params. diff --git a/manifest/clients/huggingface_embedding.py b/manifest/clients/huggingface_embedding.py index 7ba6749..02b3d81 100644 --- a/manifest/clients/huggingface_embedding.py +++ b/manifest/clients/huggingface_embedding.py @@ -58,6 +58,13 @@ class HuggingFaceEmbeddingClient(Client): """Return whether the client supports batch inference.""" return True + def supports_streaming_inference(self) -> bool: + """Return whether the client supports streaming inference. + + Override in child client class. + """ + return False + def get_model_params(self) -> Dict: """ Get model params. diff --git a/manifest/clients/openai.py b/manifest/clients/openai.py index ac0986d..6e9ef00 100644 --- a/manifest/clients/openai.py +++ b/manifest/clients/openai.py @@ -95,6 +95,13 @@ class OpenAIClient(Client): """Return whether the client supports batch inference.""" return True + def supports_streaming_inference(self) -> bool: + """Return whether the client supports streaming inference. + + Override in child client class. + """ + return True + def get_model_params(self) -> Dict: """ Get model params. 
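Only the OpenAI completion client above opts in to streaming; the other clients keep the capability flag at `False`, and `Manifest.run` (changed further below) fails fast for them. A hedged sketch of that behavior, using the built-in dummy client so no API key or server is needed; constructor defaults are assumed to be sufficient:

```python
# Sketch of the fail-fast path when a client reports
# supports_streaming_inference() == False (the dummy client is left
# non-streaming by this patch).
from manifest import Manifest

manifest = Manifest(client_name="dummy")

try:
    for token in manifest.run("Hello world", stream=True):
        print(token, end="")
except ValueError as err:
    # Raised by run() before any request is issued, e.g.
    # "Client ... does not support streaming inference."
    print(f"Streaming unavailable: {err}")
```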
diff --git a/manifest/clients/openai_chat.py b/manifest/clients/openai_chat.py index 6580251..00e49d0 100644 --- a/manifest/clients/openai_chat.py +++ b/manifest/clients/openai_chat.py @@ -129,6 +129,11 @@ class OpenAIChatClient(OpenAIClient): new_choices = [] response = copy.deepcopy(response) for message in response["choices"]: - new_choices.append({"text": message["message"]["content"]}) + if "delta" in message: + # This is a streaming response + if "content" in message["delta"]: + new_choices.append({"text": message["delta"]["content"]}) + else: + new_choices.append({"text": message["message"]["content"]}) response["choices"] = new_choices return super().postprocess_response(response, request) diff --git a/manifest/clients/openai_embedding.py b/manifest/clients/openai_embedding.py index 41d893e..27c116f 100644 --- a/manifest/clients/openai_embedding.py +++ b/manifest/clients/openai_embedding.py @@ -76,6 +76,13 @@ class OpenAIEmbeddingClient(OpenAIClient): """ return {"model_name": self.NAME, "engine": getattr(self, "engine")} + def supports_streaming_inference(self) -> bool: + """Return whether the client supports streaming inference. + + Override in child client class. + """ + return False + def postprocess_response(self, response: Dict, request: Dict) -> Dict[str, Any]: """ Format response to dict. diff --git a/manifest/clients/toma.py b/manifest/clients/toma.py index 47f0d2a..417db97 100644 --- a/manifest/clients/toma.py +++ b/manifest/clients/toma.py @@ -111,6 +111,13 @@ class TOMAClient(Client): """Return whether the client supports batch inference.""" return False + def supports_streaming_inference(self) -> bool: + """Return whether the client supports streaming inference. + + Override in child client class. + """ + return False + def get_model_params(self) -> Dict: """ Get model params. diff --git a/manifest/manifest.py b/manifest/manifest.py index 52ee901..45413c4 100644 --- a/manifest/manifest.py +++ b/manifest/manifest.py @@ -2,7 +2,18 @@ import asyncio import copy import logging -from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast +from typing import ( + Any, + Dict, + Generator, + Iterator, + List, + Optional, + Tuple, + Type, + Union, + cast, +) import numpy as np @@ -291,8 +302,17 @@ class Manifest: overwrite_cache: bool = False, stop_token: Optional[str] = None, return_response: bool = False, + stream: bool = False, **kwargs: Any, - ) -> Union[str, List[str], np.ndarray, List[np.ndarray], Response]: + ) -> Union[ + str, + List[str], + np.ndarray, + List[np.ndarray], + Response, + Iterator[str], + Iterator[Response], + ]: """ Run the prompt. @@ -302,9 +322,11 @@ class Manifest: prompt: prompt(s) to run. overwrite_cache: whether to overwrite cache. stop_token: stop token for prompt generation. - Default is self.stop_token. - "" for no stop token. + Default is self.stop_token. + "" for no stop token. return_response: whether to return Response object. + stream: whether to stream the prompt. Only supported + for single string prompts and LMs. Returns: response from prompt. @@ -319,6 +341,24 @@ class Manifest: raise ValueError("Prompt cannot be empty list") # Get the client to run client = self.client_pool.get_next_client() + if stream: + if not client.supports_streaming_inference(): + raise ValueError( + f"Client {client} does not support streaming inference." + ) + if not isinstance(prompt, str): + raise ValueError( + "Stream is only supported for single string prompts. " + "It will soon be supported for chat dictionary prompts, too." 
+ ) + return self._run_stream( + prompt=cast(str, prompt), + client=client, + overwrite_cache=overwrite_cache, + stop_token=stop_token, + return_response=return_response, + **kwargs, + ) if isinstance(prompt, list) and isinstance(prompt[0], dict): if not client.IS_CHAT: raise ValueError( @@ -337,15 +377,14 @@ class Manifest: return_response=return_response, **kwargs, ) - else: - return self._run( - prompt=cast(Union[str, List[str]], prompt), - client=client, - overwrite_cache=overwrite_cache, - stop_token=stop_token, - return_response=return_response, - **kwargs, - ) + return self._run( + prompt=cast(Union[str, List[str]], prompt), + client=client, + overwrite_cache=overwrite_cache, + stop_token=stop_token, + return_response=return_response, + **kwargs, + ) def _run( self, @@ -399,7 +438,6 @@ class Manifest: response=response, cached_idx_to_response=cached_idx_to_response, ) - # Extract text results if return_response: return final_response @@ -467,6 +505,77 @@ class Manifest: else: return cast(str, final_response.get_response("", is_batch)) + def _run_stream( + self, + prompt: str, + client: Client, + overwrite_cache: bool = False, + stop_token: Optional[str] = None, + return_response: bool = False, + **kwargs: Any, + ) -> Union[Generator[str, None, None], Generator[Response, None, None]]: + """ + Run the prompt in a stream. + + Args: + prompt: prompt(s) to run. + client: client to run. + overwrite_cache: whether to overwrite cache. + stop_token: stop token for prompt generation. + Default is self.stop_token. + "" for no stop token. + return_response: whether to return Response object. + + Returns: + response from prompt. + """ + is_batch = False + stop_token = stop_token if stop_token is not None else self.stop_token + # Must pass kwargs as dict for client "pop" methods removed used arguments + request_params = client.get_request(prompt, kwargs) + # Avoid nested list of results - enforce n = 1 for batch + if request_params.n > 1: + raise ValueError("Stream mode does not support n > 1.") + self._validate_kwargs(kwargs, request_params) + + cached_idx_to_response, request_params = self._split_cached_requests( + request_params, client, overwrite_cache + ) + if request_params.prompt: + # Because we are streaming, we should have either a cached response + # a prompt to run + assert len(cached_idx_to_response) == 0 + response_iter = client.run_streaming_request(request_params) + is_cached = False + else: + assert len(cached_idx_to_response) == 1 + response_iter = cached_idx_to_response[0].as_iter() + is_cached = True + + saved_responses = [] + # Start timing metrics + self.client_pool.start_timer() + for response_token in response_iter: + saved_responses.append(response_token) + if return_response: + yield response_token + else: + yield cast( + Union[str, Response], response_token.get_response("", is_batch) + ) + self.client_pool.end_timer() + + if not is_cached: + final_response = Response.union_all( + saved_responses, as_single_lmchoice=True + ) + self._stitch_responses_and_cache( + request=request_params, + client=client, + response=final_response, + cached_idx_to_response=cached_idx_to_response, + ) + async def arun_batch( self, prompts: List[str], diff --git a/manifest/response.py b/manifest/response.py index 3b9498b..7760b8b 100644 --- a/manifest/response.py +++ b/manifest/response.py @@ -1,7 +1,7 @@ """Client response.""" import copy import json -from typing import Any, Dict, List, Optional, Type, Union, cast +from typing import Any, Dict, Generator, List, Optional, Type, Union, cast 
import numpy as np from pydantic import BaseModel @@ -154,9 +154,7 @@ class Response: stop_token: stop token for string generation is_batch: whether response is batched """ - process_result = ( - lambda x: x.strip().split(stop_token)[0] if stop_token else x.strip() - ) + process_result = lambda x: x.split(stop_token)[0] if stop_token else x extracted_items = [ choice.text if isinstance(choice, LMModelChoice) else choice.array for choice in self._response.choices @@ -173,8 +171,17 @@ class Response: return processed_results @classmethod - def union_all(cls, responses: List["Response"]) -> "Response": - """Union a list of response.""" + def union_all( + cls, responses: List["Response"], as_single_lmchoice: bool = False + ) -> "Response": + """Union a list of response. + + Args: + responses: list of responses to union. + as_single_lmchoice: if True, will concatenate all responses into a single + model choice. Useful for merging streaming responses. Only valid + for LMRequest responses. + """ if not responses: raise ValueError("Response list is empty.") if len(responses) == 1: @@ -184,6 +191,9 @@ class Response: response_type = first_response._response_type request = first_response.get_request_obj() + if as_single_lmchoice and response_type != "text": + raise ValueError("as_single_lmchoice=True only works for text responses.") + # Make sure all responses have the same keys if not all( [ @@ -197,7 +207,7 @@ class Response: # Get all the prompts and model choices all_prompts = [] all_choices = [] - all_usages = [] + all_usages: List[Usage] = [] all_engines = [] for res in responses: all_engines.extend(res.get_request_obj().engine.split(ENGINE_SEP)) @@ -213,18 +223,115 @@ class Response: all_usages.extend([Usage()] * len(res_prompt)) new_request = copy.deepcopy(request) new_request.engine = ENGINE_SEP.join(sorted(set(all_engines))) - new_request.prompt = all_prompts - new_response = ModelChoices(choices=all_choices) - new_usages = Usages(usages=all_usages) - response_obj = cls( - response=new_response, - cached=any(res.is_cached() for res in responses), - request=new_request, - usages=new_usages, - request_type=request_type, - response_type=response_type, - ) - return response_obj + + if as_single_lmchoice: + if len(set(all_prompts)) != 1: + raise ValueError("Prompts must be the same for as_single_lmchoice=True") + all_choices_txt = cast(List[LMModelChoice], all_choices) # type: ignore + single_prompt = all_prompts[0] + single_text = "".join([choice.text for choice in all_choices_txt]) + single_logprobs = [ + logprob + for choice in all_choices_txt + for logprob in choice.token_logprobs or [] + ] + single_tokens = [ + token for choice in all_choices_txt for token in choice.tokens or [] + ] + single_usage = Usage( + completion_tokens=sum(usg.completion_tokens for usg in all_usages), + prompt_tokens=sum(usg.prompt_tokens for usg in all_usages), + total_tokens=sum(usg.total_tokens for usg in all_usages), + ) + new_choices = [ + LMModelChoice( + text=single_text, + token_logprobs=single_logprobs, + tokens=single_tokens, + ) + ] + new_responses = ModelChoices(choices=new_choices) # type: ignore + new_usages = Usages(usages=[single_usage]) + new_request.prompt = single_prompt + response_obj = cls( + response=new_responses, + cached=any(res.is_cached() for res in responses), + request=new_request, + usages=new_usages, + request_type=request_type, + response_type=response_type, + ) + return response_obj + else: + new_request.prompt = all_prompts + new_response = ModelChoices(choices=all_choices) + 
new_usages = Usages(usages=all_usages) + response_obj = cls( + response=new_response, + cached=any(res.is_cached() for res in responses), + request=new_request, + usages=new_usages, + request_type=request_type, + response_type=response_type, + ) + return response_obj + + # Return a token by token iterator over the response + def as_iter(self) -> Generator["Response", None, None]: + """Return a token by token iterator over the response. + + Will return iterator of responses with one token each. + """ + if self._response_type not in {"text"}: + raise ValueError( + f"Invalid response type {self._response_type} for as_iter()" + ) + if not self._response.choices: + raise ValueError("No choices in response.") + if len(self._response.choices) > 1: + raise ValueError( + "Response has more than one choice. as_iter() " + "should be over single choice responses." + ) + if not isinstance(self._response.choices[0], LMModelChoice): + raise ValueError( + "response_type is text but response is " + f"{self._response.choices[0].__class__}" + ) + choice = cast(LMModelChoice, self._response.choices[0]) + # If tokens, return iterator of tokens + if choice.tokens: + for token, logprob in zip(choice.tokens, choice.token_logprobs): + yield Response( + response=ModelChoices( + choices=[ + LMModelChoice( + text=token, token_logprobs=[logprob], tokens=[token] + ) + ] + ), + cached=self._cached, + request=self._request, + usages=self._usages, + request_type=self._request_type, + response_type=self._response_type, + ) + # Otherwise, do it by words + else: + for i, word in enumerate(choice.text.split(" ")): + word = " " + word if i > 0 else word + yield Response( + response=ModelChoices( + choices=[ + LMModelChoice(text=word, token_logprobs=None, tokens=None) + ] + ), + cached=self._cached, + request=self._request, + usages=self._usages, + request_type=self._request_type, + response_type=self._response_type, + ) def serialize(self) -> str: """ diff --git a/tests/conftest.py b/tests/conftest.py index 1a300a5..b3a3e75 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,8 +17,10 @@ def model_choice() -> ModelChoices: """Get dummy model choice.""" model_choices = ModelChoices( choices=[ - LMModelChoice(text="hello", token_logprobs=[0.1, 0.2]), - LMModelChoice(text="bye", token_logprobs=[0.3]), + LMModelChoice( + text="hello", token_logprobs=[0.1, 0.2], tokens=["hel", "lo"] + ), + LMModelChoice(text="bye", token_logprobs=[0.3], tokens=["bye"]), ] ) return model_choices @@ -29,7 +31,9 @@ def model_choice_single() -> ModelChoices: """Get dummy model choice.""" model_choices = ModelChoices( choices=[ - LMModelChoice(text="helloo", token_logprobs=[0.1, 0.2]), + LMModelChoice( + text="helloo", token_logprobs=[0.1, 0.2], tokens=["hel", "loo"] + ), ] ) return model_choices diff --git a/tests/test_manifest.py b/tests/test_manifest.py index e667a54..8b7ac6c 100644 --- a/tests/test_manifest.py +++ b/tests/test_manifest.py @@ -1,7 +1,7 @@ """Manifest test.""" import asyncio import os -from typing import cast +from typing import Iterator, cast from unittest.mock import MagicMock, Mock, patch import numpy as np @@ -787,6 +787,45 @@ def test_openai(sqlite_cache: str) -> None: ) assert response.is_cached() is True + # Test streaming + num_responses = 0 + streaming_response_text = cast( + Iterator[str], client.run("Why are there oranges?", stream=True) + ) + for res_text in streaming_response_text: + num_responses += 1 + assert isinstance(res_text, str) and len(res_text) > 0 + assert num_responses == 8 + + streaming_response = 
cast( + Iterator[Response], + client.run("Why are there mandarines?", return_response=True, stream=True), + ) + num_responses = 0 + merged_res = [] + for res in streaming_response: + num_responses += 1 + assert isinstance(res, Response) and len(res.get_response()) > 0 + merged_res.append(cast(str, res.get_response())) + assert not res.is_cached() + assert num_responses == 10 + + # Make sure cached + streaming_response = cast( + Iterator[Response], + client.run("Why are there mandarines?", return_response=True, stream=True), + ) + num_responses = 0 + merged_res_cachced = [] + for res in streaming_response: + num_responses += 1 + assert isinstance(res, Response) and len(res.get_response()) > 0 + merged_res_cachced.append(cast(str, res.get_response())) + assert res.is_cached() + # OpenAI stream does not return logprobs, so this is by number of words + assert num_responses == 7 + assert "".join(merged_res) == "".join(merged_res_cachced) + @pytest.mark.skipif(not OPENAI_ALIVE, reason="No openai key set") @pytest.mark.usefixtures("sqlite_cache") @@ -796,6 +835,7 @@ def test_openaichat(sqlite_cache: str) -> None: client_name="openaichat", cache_name="sqlite", cache_connection=sqlite_cache, + temperature=0.0, ) res = client.run("Why are there apples?") @@ -868,6 +908,45 @@ def test_openaichat(sqlite_cache: str) -> None: response = cast(Response, client.run(chat_dict, return_response=True)) assert response.is_cached() is False + # Test streaming + num_responses = 0 + streaming_response_text = cast( + Iterator[str], client.run("Why are there oranges?", stream=True) + ) + for res_text in streaming_response_text: + num_responses += 1 + assert isinstance(res_text, str) and len(res_text) > 0 + assert num_responses == 9 + + streaming_response = cast( + Iterator[Response], + client.run("Why are there mandarines?", return_response=True, stream=True), + ) + num_responses = 0 + merged_res = [] + for res in streaming_response: + num_responses += 1 + assert isinstance(res, Response) and len(res.get_response()) > 0 + merged_res.append(cast(str, res.get_response())) + assert not res.is_cached() + assert num_responses == 10 + + # Make sure cached + streaming_response = cast( + Iterator[Response], + client.run("Why are there mandarines?", return_response=True, stream=True), + ) + num_responses = 0 + merged_res_cachced = [] + for res in streaming_response: + num_responses += 1 + assert isinstance(res, Response) and len(res.get_response()) > 0 + merged_res_cachced.append(cast(str, res.get_response())) + assert res.is_cached() + # OpenAI stream does not return logprobs, so this is by number of words + assert num_responses == 7 + assert "".join(merged_res) == "".join(merged_res_cachced) + @pytest.mark.skipif(not OPENAI_ALIVE, reason="No openai key set") @pytest.mark.usefixtures("sqlite_cache") @@ -1156,7 +1235,7 @@ def test_retry_handling() -> None: with patch("manifest.clients.client.requests.post", mock_create): # Run manifest result = client.run(prompts, temperature=0, overwrite_cache=True) - assert result == ["WHATTT.", "UH OH.", "HARG"] + assert result == [" WHATTT.", " UH OH.", " HARG"] # Assert that OpenAI client was called twice assert mock_create.call_count == 2 diff --git a/tests/test_response.py b/tests/test_response.py index 3208876..eac0123 100644 --- a/tests/test_response.py +++ b/tests/test_response.py @@ -6,7 +6,13 @@ import pytest from manifest import Response from manifest.request import EmbeddingRequest, LMRequest -from manifest.response import ArrayModelChoice, ModelChoices, Usage, Usages +from 
manifest.response import ( + ArrayModelChoice, + LMModelChoice, + ModelChoices, + Usage, + Usages, +) def test_init( @@ -275,9 +281,9 @@ def test_union_all( final_response = Response.union_all([response1, response2]) assert final_response.get_json_response() == { "choices": [ - {"text": "hello", "token_logprobs": [0.1, 0.2], "tokens": None}, - {"text": "bye", "token_logprobs": [0.3], "tokens": None}, - {"text": "helloo", "token_logprobs": [0.1, 0.2], "tokens": None}, + {"text": "hello", "token_logprobs": [0.1, 0.2], "tokens": ["hel", "lo"]}, + {"text": "bye", "token_logprobs": [0.3], "tokens": ["bye"]}, + {"text": "helloo", "token_logprobs": [0.1, 0.2], "tokens": ["hel", "loo"]}, ] } assert final_response.get_usage_obj() == Usages(usages=[Usage(), Usage(), Usage()]) @@ -299,3 +305,83 @@ def test_union_all( assert final_response.get_usage_obj() == Usages( usages=[Usage(total_tokens=4), Usage(total_tokens=6), Usage()] ) + + # Test merge to single + model_choices = ModelChoices( + choices=[ + LMModelChoice( + text=" helloo this is a bug", + token_logprobs=[0.1, 0.2, 0.3], + tokens=[" helloo", " this is", " a bug"], + ), + ] + ) + request = LMRequest(prompt="monkey", engine="dummy") + response1 = Response( + response=model_choices, + cached=False, + request=request, + usages=None, + request_type=LMRequest, + response_type="text", + ) + final_response = Response.union_all([response1, response1], as_single_lmchoice=True) + assert final_response.get_json_response() == { + "choices": [ + { + "text": " helloo this is a bug helloo this is a bug", + "token_logprobs": [0.1, 0.2, 0.3, 0.1, 0.2, 0.3], + "tokens": [ + " helloo", + " this is", + " a bug", + " helloo", + " this is", + " a bug", + ], + }, + ] + } + assert final_response.get_usage_obj() == Usages(usages=[Usage()]) + assert final_response.get_request_obj().prompt == "monkey" + assert final_response.get_request_obj().engine == "dummy" + + +def test_as_iter( + model_choice_single: ModelChoices, request_lm_single: LMRequest +) -> None: + """Test as iter.""" + response = Response( + response=model_choice_single, + cached=False, + request=request_lm_single, + usages=None, + request_type=LMRequest, + response_type="text", + ) + response_iter_list = list(response.as_iter()) + assert len(response_iter_list) == 2 + assert response_iter_list[0].get_response() == "hel" + assert response_iter_list[1].get_response() == "loo" + + model_choices = ModelChoices( + choices=[ + LMModelChoice(text="helloo this is a bug"), + ] + ) + request = LMRequest(prompt="monkey", engine="dummy") + response = Response( + response=model_choices, + cached=False, + request=request, + usages=None, + request_type=LMRequest, + response_type="text", + ) + response_iter_list = list(response.as_iter()) + assert len(response_iter_list) == 5 + assert response_iter_list[0].get_response() == "helloo" + assert response_iter_list[1].get_response() == " this" + assert response_iter_list[2].get_response() == " is" + assert response_iter_list[3].get_response() == " a" + assert response_iter_list[4].get_response() == " bug"
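Taken together, the `Response` changes give streaming its caching story: live token responses are merged into one model choice with `union_all(as_single_lmchoice=True)` before being written to cache, and a cache hit is replayed through `as_iter()`. A small self-contained illustration, mirroring the tests above (the `chunk` helper and its texts are made up for the example):

```python
from manifest.request import LMRequest
from manifest.response import LMModelChoice, ModelChoices, Response


def chunk(text: str) -> Response:
    """Build a one-choice text Response, like a single streamed chunk."""
    return Response(
        response=ModelChoices(choices=[LMModelChoice(text=text)]),
        cached=False,
        request=LMRequest(prompt="hi", engine="dummy"),
        usages=None,
        request_type=LMRequest,
        response_type="text",
    )


# Merge streamed chunks into a single model choice (done before caching).
merged = Response.union_all(
    [chunk("Hel"), chunk("lo there")], as_single_lmchoice=True
)
print(merged.get_response())  # "Hello there"

# Replay a response incrementally, as the cache-hit path does; without
# token-level data this falls back to word-by-word chunks.
print([r.get_response() for r in merged.as_iter()])  # ["Hello", " there"]
```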