diff --git a/README.md b/README.md
index 3962c50..beb8dec 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,5 @@
-# manifest
-Prompt programming with FMs.
+# Manifest
+How to make prompt programming with FMs a little easier.
 
 # Install
 Download the code:
@@ -11,29 +11,63 @@ cd manifest
 Install:
 ```bash
 pip install poetry
-poetry install
-poetry run pre-commit install
+poetry install --no-dev
 ```
-or
+
+Dev Install:
 ```bash
 pip install poetry
 make dev
 ```
-# Run
-Manifest is meant to be a very light weight package to help with prompt iteration. Two key design decisions are
+
+# Getting Started
+Getting started is simple. If using OpenAI, first set `export OPENAI_API_KEY=`, then run
+
+```python
+from manifest import Manifest
+
+# Start a manifest session
+manifest = Manifest(
+    client_name = "openai",
+)
+manifest.run("Why is the grass green?")
+```
+
+We also support AI21, OPT, and HuggingFace models (see [below](#huggingface-models)).
+
+Caching is turned off by default. To cache results, run
+
+```python
+from manifest import Manifest
+
+# Start a manifest session
+manifest = Manifest(
+    client_name = "openai",
+    cache_name = "sqlite",
+    cache_connection = "mycache.sqlite",
+)
+manifest.run("Why is the grass green?")
+```
+
+We also support a Redis backend (see the example at the end of the [HuggingFace Models](#huggingface-models) section).
+
+# Manifest Components
+Manifest is meant to be a very lightweight package to help with prompt iteration. Three key design decisions are
 * Prompts are functional -- they can take an input example and dynamically change
 * All models are behind API calls (e.g., OpenAI)
-* Everything is cached for reuse to both save credits and to explore past results
+* Everything can be cached for reuse to both save credits and to explore past results
 
 ## Prompts
 A Manifest prompt is a function that accepts a single input to generate a string prompt to send to a model.
+
 ```python
 from manifest import Prompt
 prompt = Prompt(lambda x: f"Hello, my name is {x}")
 print(prompt("Laurel"))
 >>> "Hello, my name is Laurel"
 ```
+
 We also let you use static strings
 ```python
 prompt = Prompt("Hello, my name is static")
@@ -41,8 +75,6 @@
 print(prompt())
 >>> "Hello, my name is static"
 ```
 
-**Chaining prompts coming soon**
-
 ## Sessions
 Each Manifest run is a session that connects to a model endpoint and backend database to record prompt queries. To start a Manifest session for OpenAI, make sure you run
@@ -51,7 +83,7 @@ export OPENAI_API_KEY=
 ```
 so we can access OpenAI.
 
-Then, in a notebook, run:
+Then run:
 
 ```python
 from manifest import Manifest
@@ -104,7 +136,8 @@ If you want to change default parameters to a model, we pass those as `kwargs` t
 ```python
 result = manifest.run(prompt, "Laurel", max_tokens=50)
 ```
-# Huggingface Models
+
+## Huggingface Models
 To use a HuggingFace generative model, in `manifest/api` we have a Flask application that hosts the models for you.
 
 In a separate terminal or Tmux/Screen session, run
@@ -117,15 +150,11 @@ You will see the Flask session start and output a URL `http://127.0.0.1:5000`. P
 ```python
 manifest = Manifest(
     client_name = "huggingface",
     client_connection = "http://127.0.0.1:5000",
-    cache_name = "redis",
-    cache_connection = "localhost:6379"
 )
 ```
 If you have a custom model you trained, pass the model path to `--model_name`.
 
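+For example, to cache results from a locally hosted HuggingFace model in Redis -- a sketch that reuses the Redis parameters from an earlier revision of this README and assumes a Redis instance is listening on `localhost:6379`:
+
+```python
+manifest = Manifest(
+    client_name = "huggingface",
+    client_connection = "http://127.0.0.1:5000",
+    cache_name = "redis",
+    cache_connection = "localhost:6379",
+)
+```
+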
-**Auto deployment coming soon**
-
 # Development
 Before submitting a PR, run
 ```bash
diff --git a/manifest/clients/ai21.py b/manifest/clients/ai21.py
index 00dfe58..c9c331b 100644
--- a/manifest/clients/ai21.py
+++ b/manifest/clients/ai21.py
@@ -66,7 +66,39 @@ class AI21Client(Client):
         """
         return {"model_name": "ai21", "engine": self.engine}
 
-    def get_request(self, query: str, **kwargs: Any) -> Tuple[Callable[[], Dict], Dict]:
+    def format_response(self, response: Dict) -> Dict[str, Any]:
+        """
+        Format response to dict.
+
+        Args:
+            response: response
+
+        Return:
+            response as dict
+        """
+        return {
+            "object": "text_completion",
+            "model": self.engine,
+            "choices": [
+                {
+                    "text": item["data"]["text"],
+                    "logprobs": [
+                        {
+                            "token": tok["generatedToken"]["token"],
+                            "logprob": tok["generatedToken"]["logprob"],
+                            "start": tok["textRange"]["start"],
+                            "end": tok["textRange"]["end"],
+                        }
+                        for tok in item["data"]["tokens"]
+                    ],
+                }
+                for item in response["completions"]
+            ],
+        }
+
+    def get_request(
+        self, query: str, request_args: Dict[str, Any] = {}
+    ) -> Tuple[Callable[[], Dict], Dict]:
         """
         Get request string function.
 
@@ -78,26 +110,22 @@ class AI21Client(Client):
             request parameters as dict.
         """
         request_params = {
-            "engine": kwargs.get("engine", self.engine),
+            "engine": request_args.pop("engine", self.engine),
             "prompt": query,
-            "temperature": kwargs.get("temperature", self.temperature),
-            "maxTokens": kwargs.get("maxTokens", self.max_tokens),
-            "topKReturn": kwargs.get("topKReturn", self.top_k_return),
-            "numResults": kwargs.get("numResults", self.num_results),
-            "topP": kwargs.get("topP", self.top_p),
+            "temperature": request_args.pop("temperature", self.temperature),
+            "maxTokens": request_args.pop("maxTokens", self.max_tokens),
+            "topKReturn": request_args.pop("topKReturn", self.top_k_return),
+            "numResults": request_args.pop("numResults", self.num_results),
+            "topP": request_args.pop("topP", self.top_p),
         }
 
         def _run_completion() -> Dict:
             post_str = self.host + "/" + self.engine + "/complete"
-            print(self.api_key)
-            print(post_str)
-            print("https://api.ai21.com/studio/v1/j1-large/complete")
-            print(request_params)
             res = requests.post(
                 post_str,
                 headers={"Authorization": f"Bearer {self.api_key}"},
                 json=request_params,
             )
-            return res.json()
+            return self.format_response(res.json())
 
         return _run_completion, request_params
diff --git a/manifest/clients/client.py b/manifest/clients/client.py
index 58b391f..20f1ffc 100644
--- a/manifest/clients/client.py
+++ b/manifest/clients/client.py
@@ -52,7 +52,9 @@ class Client(ABC):
         raise NotImplementedError()
 
     @abstractmethod
-    def get_request(self, query: str, **kwargs: Any) -> Tuple[Callable[[], Dict], Dict]:
+    def get_request(
+        self, query: str, request_args: Dict[str, Any] = {}
+    ) -> Tuple[Callable[[], Dict], Dict]:
         """
         Get request function.
 
@@ -62,6 +64,7 @@
 
         Args:
             query: query string.
+            request_args: request arguments.
 
         Returns:
             request function that takes no input.
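The new `get_request` contract across clients replaces `**kwargs` with a caller-supplied dict and `pop`s each argument it understands, so whatever is left over can be flagged as unrecognized. A minimal sketch of that flow, mirroring `tests/test_client.py` with the dummy client (`bad_input` is a hypothetical misspelled argument):

```python
from manifest.clients.dummy import DummyClient

client = DummyClient(connection_str=None)

# get_request pops the arguments it understands out of the dict...
request_args = {"num_results": 5, "bad_input": 1}
request_func, request_params = client.get_request("hello", request_args)
assert request_params == {"prompt": "hello", "num_results": 5}

# ...and leaves the rest behind, which Manifest.run uses to raise a ValueError.
assert request_args == {"bad_input": 1}
```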
diff --git a/manifest/clients/crfm.py b/manifest/clients/crfm.py
index f588a4d..e7ff3f6 100644
--- a/manifest/clients/crfm.py
+++ b/manifest/clients/crfm.py
@@ -86,6 +86,7 @@ class CRFMClient(Client):
             response as dict
         """
         return {
+            "id": response.id,
             "object": "text_completion",
             "model": self.engine,
             "choices": [
@@ -104,7 +105,9 @@ class CRFMClient(Client):
             ],
         }
 
-    def get_request(self, query: str, **kwargs: Any) -> Tuple[Callable[[], Dict], Dict]:
+    def get_request(
+        self, query: str, request_args: Dict[str, Any] = {}
+    ) -> Tuple[Callable[[], Dict], Dict]:
         """
         Get request string function.
 
@@ -116,16 +119,22 @@ class CRFMClient(Client):
             request parameters as dict.
         """
         request_params = {
-            "model": kwargs.get("engine", self.engine),
+            "model": request_args.pop("engine", self.engine),
             "prompt": query,
-            "temperature": kwargs.get("temperature", self.temperature),
-            "max_tokens": kwargs.get("max_tokens", self.max_tokens),
-            "top_k_per_token": kwargs.get("top_k_per_token", self.top_k_per_token),
-            "num_completions": kwargs.get("num_completions", self.num_completions),
-            "stop_sequences": kwargs.get("stop_sequences", self.stop_sequences),
-            "top_p": kwargs.get("top_p", self.top_p),
-            "presence_penalty": kwargs.get("presence_penalty", self.presence_penalty),
-            "frequency_penalty": kwargs.get(
+            "temperature": request_args.pop("temperature", self.temperature),
+            "max_tokens": request_args.pop("max_tokens", self.max_tokens),
+            "top_k_per_token": request_args.pop(
+                "top_k_per_token", self.top_k_per_token
+            ),
+            "num_completions": request_args.pop(
+                "num_completions", self.num_completions
+            ),
+            "stop_sequences": request_args.pop("stop_sequences", self.stop_sequences),
+            "top_p": request_args.pop("top_p", self.top_p),
+            "presence_penalty": request_args.pop(
+                "presence_penalty", self.presence_penalty
+            ),
+            "frequency_penalty": request_args.pop(
                 "frequency_penalty", self.frequency_penalty
             ),
         }
diff --git a/manifest/clients/dummy.py b/manifest/clients/dummy.py
index d57172e..f6df36c 100644
--- a/manifest/clients/dummy.py
+++ b/manifest/clients/dummy.py
@@ -42,7 +42,9 @@ class DummyClient(Client):
         """
         return {"engine": "dummy"}
 
-    def get_request(self, query: str, **kwargs: Any) -> Tuple[Callable[[], Dict], Dict]:
+    def get_request(
+        self, query: str, request_args: Dict[str, Any] = {}
+    ) -> Tuple[Callable[[], Dict], Dict]:
         """
         Get request string function.
 
@@ -55,10 +57,10 @@ class DummyClient(Client):
         """
         request_params = {
             "prompt": query,
-            "num_results": kwargs.get("num_results", self.num_results),
+            "num_results": request_args.pop("num_results", self.num_results),
         }
 
         def _run_completion() -> Dict:
-            return {"choices": [{"text": "hello"}] * self.num_results}
+            return {"choices": [{"text": "hello"}] * request_params["num_results"]}
 
         return _run_completion, request_params
diff --git a/manifest/clients/huggingface.py b/manifest/clients/huggingface.py
index cf53327..13ec90e 100644
--- a/manifest/clients/huggingface.py
+++ b/manifest/clients/huggingface.py
@@ -51,7 +51,9 @@ class HuggingFaceClient(Client):
         res = requests.post(self.host + "/params")
         return res.json()
 
-    def get_request(self, query: str, **kwargs: Any) -> Tuple[Callable[[], Dict], Dict]:
+    def get_request(
+        self, query: str, request_args: Dict[str, Any] = {}
+    ) -> Tuple[Callable[[], Dict], Dict]:
         """
         Get request string function.
 
@@ -64,15 +66,15 @@ class HuggingFaceClient(Client):
         """
         request_params = {
             "prompt": query,
-            "temperature": kwargs.get("temperature", self.temperature),
-            "max_tokens": kwargs.get("max_tokens", self.max_tokens),
-            "top_p": kwargs.get("top_p", self.top_p),
-            "top_k": kwargs.get("top_k", self.top_k),
-            "do_sample": kwargs.get("do_sample", self.do_sample),
-            "repetition_penalty": kwargs.get(
+            "temperature": request_args.pop("temperature", self.temperature),
+            "max_tokens": request_args.pop("max_tokens", self.max_tokens),
+            "top_p": request_args.pop("top_p", self.top_p),
+            "top_k": request_args.pop("top_k", self.top_k),
+            "do_sample": request_args.pop("do_sample", self.do_sample),
+            "repetition_penalty": request_args.pop(
                 "repetition_penalty", self.repetition_penalty
             ),
-            "n": kwargs.get("n", self.n),
+            "n": request_args.pop("n", self.n),
         }
         request_params.update(self.model_params)
diff --git a/manifest/clients/openai.py b/manifest/clients/openai.py
index 6bb1ae3..aba4c79 100644
--- a/manifest/clients/openai.py
+++ b/manifest/clients/openai.py
@@ -71,7 +71,9 @@ class OpenAIClient(Client):
         """
         return {"model_name": "openai", "engine": self.engine}
 
-    def get_request(self, query: str, **kwargs: Any) -> Tuple[Callable[[], Dict], Dict]:
+    def get_request(
+        self, query: str, request_args: Dict[str, Any] = {}
+    ) -> Tuple[Callable[[], Dict], Dict]:
         """
         Get request string function.
 
@@ -83,18 +85,20 @@ class OpenAIClient(Client):
             request parameters as dict.
         """
         request_params = {
-            "engine": kwargs.get("engine", self.engine),
+            "engine": request_args.pop("engine", self.engine),
             "prompt": query,
-            "temperature": kwargs.get("temperature", self.temperature),
-            "max_tokens": kwargs.get("max_tokens", self.max_tokens),
-            "top_p": kwargs.get("top_p", self.top_p),
-            "frequency_penalty": kwargs.get(
+            "temperature": request_args.pop("temperature", self.temperature),
+            "max_tokens": request_args.pop("max_tokens", self.max_tokens),
+            "top_p": request_args.pop("top_p", self.top_p),
+            "frequency_penalty": request_args.pop(
                 "frequency_penalty", self.frequency_penalty
             ),
-            "logprobs": kwargs.get("logprobs", self.logprobs),
-            "best_of": kwargs.get("best_of", self.best_of),
-            "presence_penalty": kwargs.get("presence_penalty", self.presence_penalty),
-            "n": kwargs.get("n", self.n),
+            "logprobs": request_args.pop("logprobs", self.logprobs),
+            "best_of": request_args.pop("best_of", self.best_of),
+            "presence_penalty": request_args.pop(
+                "presence_penalty", self.presence_penalty
+            ),
+            "n": request_args.pop("n", self.n),
         }
 
         def _run_completion() -> Dict:
diff --git a/manifest/clients/opt.py b/manifest/clients/opt.py
index 91592e4..915e0db 100644
--- a/manifest/clients/opt.py
+++ b/manifest/clients/opt.py
@@ -46,7 +46,9 @@ class OPTClient(Client):
         """
         return {"model_name": "opt"}
 
-    def get_request(self, query: str, **kwargs: Any) -> Tuple[Callable[[], Dict], Dict]:
+    def get_request(
+        self, query: str, request_args: Dict[str, Any] = {}
+    ) -> Tuple[Callable[[], Dict], Dict]:
         """
         Get request string function.
 
@@ -60,10 +62,10 @@ class OPTClient(Client):
         request_params = {
             "prompt": query,
             "engine": "opt",
-            "temperature": kwargs.get("temperature", self.temperature),
-            "max_tokens": kwargs.get("max_tokens", self.max_tokens),
-            "top_p": kwargs.get("top_p", self.top_p),
-            "n": kwargs.get("n", self.n),
+            "temperature": request_args.pop("temperature", self.temperature),
+            "max_tokens": request_args.pop("max_tokens", self.max_tokens),
+            "top_p": request_args.pop("top_p", self.top_p),
+            "n": request_args.pop("n", self.n),
         }
 
         def _run_completion() -> Dict:
diff --git a/manifest/manifest.py b/manifest/manifest.py
index 6acc82b..6a7ac77 100644
--- a/manifest/manifest.py
+++ b/manifest/manifest.py
@@ -121,7 +121,10 @@ class Manifest:
            prompt = Prompt(prompt)
        stop_token = stop_token if stop_token is not None else self.stop_token
        prompt_str = prompt(input)
-        possible_request, full_kwargs = self.client.get_request(prompt_str, **kwargs)
+        # Pass kwargs as a dict so the client's "pop" calls remove the used arguments
+        possible_request, full_kwargs = self.client.get_request(prompt_str, kwargs)
+        if len(kwargs) > 0:
+            raise ValueError(f"{list(kwargs.items())} arguments are not recognized.")
        # Create cache key
        cache_key = full_kwargs.copy()
        # Make query model dependent
diff --git a/manifest/response.py b/manifest/response.py
index 240f535..11034ca 100644
--- a/manifest/response.py
+++ b/manifest/response.py
@@ -25,6 +25,12 @@ class Response:
                     "Response must be serialized to a dict with a "
                     "list of choices with text field"
                 )
+        if "logprobs" in self._response["choices"][0]:
+            if not isinstance(self._response["choices"][0]["logprobs"], list):
+                raise ValueError(
+                    "Response must be serialized to a dict with a "
+                    "list of choices with logprobs field"
+                )
         self._cached = cached
         self._request_params = request_params
diff --git a/pyproject.toml b/pyproject.toml
index 29a9a47..8384359 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -22,16 +22,11 @@ openai = "^0.18.1"
 redis = "^4.3.1"
 dill = "^0.3.5"
 Flask = "^2.1.2"
-#transformers = "^4.19.2"
-#torch = "^1.11.0"
+transformers = "^4.19.2"
+torch = "^1.8"
 requests = "^2.27.1"
 tqdm = "^4.64.0"
-types-redis = "^4.2.6"
-types-requests = "^2.27.29"
-types-PyYAML = "^6.0.7"
-types-protobuf = "^3.19.21"
-types-python-dateutil = "^2.8.16"
-types-setuptools = "^57.4.17"
+uuid = "^1.30"
 
 [tool.poetry.dev-dependencies]
 black = "^22.3.0"
@@ -45,6 +40,12 @@ pytest = "^7.0.0"
 pytest-cov = "^3.0.0"
 python-dotenv = "^0.20.0"
 recommonmark = "^0.7.1"
+types-redis = "^4.2.6"
+types-requests = "^2.27.29"
+types-PyYAML = "^6.0.7"
+types-protobuf = "^3.19.21"
+types-python-dateutil = "^2.8.16"
+types-setuptools = "^57.4.17"
 
 [build-system]
 build-backend = "poetry.core.masonry.api"
diff --git a/tests/test_client.py b/tests/test_client.py
index 1e9acbf..c2ae20c 100644
--- a/tests/test_client.py
+++ b/tests/test_client.py
@@ -3,12 +3,14 @@
 Test client.
 
 We just test the dummy client as we don't want to load a model or use OpenAI tokens.
""" - from manifest.clients.dummy import DummyClient def test_init(): """Test client initialization.""" + client = DummyClient(connection_str=None) + assert client.num_results == 1 + args = {"num_results": 3} client = DummyClient(connection_str=None, client_args=args) assert client.num_results == 3 @@ -21,3 +23,7 @@ def test_get_request(): request_func, request_params = client.get_request("hello") assert request_params == {"prompt": "hello", "num_results": 3} assert request_func() == {"choices": [{"text": "hello"}] * 3} + + request_func, request_params = client.get_request("hello", {"num_results": 5}) + assert request_params == {"prompt": "hello", "num_results": 5} + assert request_func() == {"choices": [{"text": "hello"}] * 5} diff --git a/tests/test_manifest.py b/tests/test_manifest.py index 7b4207e..d955570 100644 --- a/tests/test_manifest.py +++ b/tests/test_manifest.py @@ -55,6 +55,12 @@ def test_run(sqlite_cache, num_results, return_response): cache_connection=sqlite_cache, num_results=num_results, ) + + prompt = Prompt("This is a prompt") + with pytest.raises(ValueError) as exc_info: + result = manifest.run(prompt, return_response=return_response, bad_input=5) + assert str(exc_info.value) == "[('bad_input', 5)] arguments are not recognized." + prompt = Prompt("This is a prompt") result = manifest.run(prompt, return_response=return_response) if return_response: