Merge pull request #11 from HazyResearch/laurel/dev

build: fixing transformers import
laurel/helm
Laurel Orr 2 years ago committed by GitHub
commit 382640c3e6

@ -1,5 +1,5 @@
# manifest
Prompt programming with FMs.
# Manifest
How to make prompt programming with FMs a little easier.
# Install
Download the code:
@ -11,29 +11,63 @@ cd manifest
Install:
```bash
pip install poetry
poetry install
poetry run pre-commit install
poetry install --no-dev
```
or
Dev Install:
```bash
pip install poetry
make dev
```
# Run
Manifest is meant to be a very lightweight package to help with prompt iteration. Two key design decisions are
# Getting Started
Getting started is simple. If using OpenAI, set `export OPENAI_API_KEY=<OPENAIKEY>`, then run
```python
from manifest import Manifest
# Start a manifest session
manifest = Manifest(
client_name = "openai",
)
manifest.run("Why is the grass green?")
```
We also support AI21, OPT, and HuggingFace models (see [below](#huggingface-models)).
Caching is turned off by default; to cache results, run
```python
from manifest import Manifest
# Start a manifest session
manifest = Manifest(
client_name = "openai",
cache_name = "sqlite",
cache_connection = "mycache.sqlite",
)
manifest.run("Why is the grass green?")
```
We also support a Redis backend.
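For example, the Redis cache can be configured the same way (a minimal sketch, assuming a Redis server is reachable at `localhost:6379`; the connection string format `host:port` matches the example further below):
```python
from manifest import Manifest

# Sketch of the Redis cache backend; assumes a local Redis server
# is running on the default port.
manifest = Manifest(
    client_name = "openai",
    cache_name = "redis",
    cache_connection = "localhost:6379",
)
manifest.run("Why is the grass green?")
```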
# Manifest Components
Manifest is meant to be a very lightweight package to help with prompt iteration. Three key design decisions are:
* Prompts are functional -- they can take an input example and dynamically change
* All models are behind API calls (e.g., OpenAI)
* Everything is cached for reuse to both save credits and to explore past results
* Everything can be cached for reuse, both to save credits and to explore past results
## Prompts
A Manifest prompt is a function that accepts a single input to generate a string prompt to send to a model.
```python
from manifest import Prompt
prompt = Prompt(lambda x: f"Hello, my name is {x}")
print(prompt("Laurel"))
>>> "Hello, my name is Laurel"
```
We also let you use static strings
```python
prompt = Prompt("Hello, my name is static")
@ -41,8 +75,6 @@ print(prompt())
>>> "Hello, my name is static"
```
**Chaining prompts coming soon**
## Sessions
Each Manifest run is a session that connects to a model endpoint and a backend database to record prompt queries. To start a Manifest session for OpenAI, make sure you run
@ -51,7 +83,7 @@ export OPENAI_API_KEY=<OPENAIKEY>
```
so we can access OpenAI.
Then, in a notebook, run:
Then run:
```python
from manifest import Manifest
@ -104,7 +136,8 @@ If you want to change default parameters to a model, we pass those as `kwargs` t
```python
result = manifest.run(prompt, "Laurel", max_tokens=50)
```
# Huggingface Models
## Huggingface Models
To use a HuggingFace generative model, we provide a Flask application in `manifest/api` that hosts the models for you.
In a separate terminal or Tmux/Screen session, run
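The exact launch command is elided by the diff here; a plausible invocation, assuming the Flask app's entry point is `manifest/api/app.py` (the script path and model name are assumptions; only the `--model_name` flag is confirmed by this diff):
```bash
# Hypothetical launch command; script path and model are illustrative
python3 manifest/api/app.py --model_name gpt2
```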
@ -117,15 +150,11 @@ You will see the Flask session start and output a URL `http://127.0.0.1:5000`. P
manifest = Manifest(
client_name = "huggingface",
client_connection = "http://127.0.0.1:5000",
cache_name = "redis",
cache_connection = "localhost:6379"
)
```
If you have a custom model you trained, pass the model path to `--model_name`.
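For instance (a hedged sketch, reusing the hypothetical entry point above):
```bash
# Hypothetical: point --model_name at a local checkpoint directory
python3 manifest/api/app.py --model_name /path/to/my/checkpoint
```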
**Auto deployment coming soon**
# Development
Before submitting a PR, run
```bash

@ -66,7 +66,39 @@ class AI21Client(Client):
"""
return {"model_name": "ai21", "engine": self.engine}
def get_request(self, query: str, **kwargs: Any) -> Tuple[Callable[[], Dict], Dict]:
def format_response(self, response: Dict) -> Dict[str, Any]:
"""
Format response to dict.
Args:
response: response
Return:
response as dict
"""
return {
"object": "text_completion",
"model": self.engine,
"choices": [
{
"text": item["data"]["text"],
"logprobs": [
{
"token": tok["generatedToken"]["token"],
"logprob": tok["generatedToken"]["logprob"],
"start": tok["textRange"]["start"],
"end": tok["textRange"]["end"],
}
for tok in item["data"]["tokens"]
],
}
for item in response["completions"]
],
}
def get_request(
self, query: str, request_args: Dict[str, Any] = {}
) -> Tuple[Callable[[], Dict], Dict]:
"""
Get request string function.
@ -78,26 +110,22 @@ class AI21Client(Client):
request parameters as dict.
"""
request_params = {
"engine": kwargs.get("engine", self.engine),
"engine": request_args.pop("engine", self.engine),
"prompt": query,
"temperature": kwargs.get("temperature", self.temperature),
"maxTokens": kwargs.get("maxTokens", self.max_tokens),
"topKReturn": kwargs.get("topKReturn", self.top_k_return),
"numResults": kwargs.get("numResults", self.num_results),
"topP": kwargs.get("topP", self.top_p),
"temperature": request_args.pop("temperature", self.temperature),
"maxTokens": request_args.pop("maxTokens", self.max_tokens),
"topKReturn": request_args.pop("topKReturn", self.top_k_return),
"numResults": request_args.pop("numResults", self.num_results),
"topP": request_args.pop("topP", self.top_p),
}
def _run_completion() -> Dict:
post_str = self.host + "/" + self.engine + "/complete"
print(self.api_key)
print(post_str)
print("https://api.ai21.com/studio/v1/j1-large/complete")
print(request_params)
res = requests.post(
post_str,
headers={"Authorization": f"Bearer {self.api_key}"},
json=request_params,
)
return res.json()
return self.format_response(res.json())
return _run_completion, request_params
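To make the mapping in `format_response` concrete, here is a hedged illustration of the shapes involved (field values are invented; only the key names come from the code above):
```python
# Hypothetical AI21-style payload; values are invented for illustration.
ai21_response = {
    "completions": [
        {
            "data": {
                "text": " green",
                "tokens": [
                    {
                        "generatedToken": {"token": " green", "logprob": -0.1},
                        "textRange": {"start": 0, "end": 6},
                    }
                ],
            }
        }
    ]
}

# format_response turns this into an OpenAI-style completion dict:
# {
#     "object": "text_completion",
#     "model": "<engine>",
#     "choices": [
#         {
#             "text": " green",
#             "logprobs": [
#                 {"token": " green", "logprob": -0.1, "start": 0, "end": 6}
#             ],
#         }
#     ],
# }
```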

@ -52,7 +52,9 @@ class Client(ABC):
raise NotImplementedError()
@abstractmethod
def get_request(self, query: str, **kwargs: Any) -> Tuple[Callable[[], Dict], Dict]:
def get_request(
self, query: str, request_args: Dict[str, Any] = {}
) -> Tuple[Callable[[], Dict], Dict]:
"""
Get request function.
@ -62,6 +64,7 @@ class Client(ABC):
Args:
query: query string.
request_args: request arguments.
Returns:
request function that takes no input.

@ -86,6 +86,7 @@ class CRFMClient(Client):
response as dict
"""
return {
"id": response.id,
"object": "text_completion",
"model": self.engine,
"choices": [
@ -104,7 +105,9 @@ class CRFMClient(Client):
],
}
def get_request(self, query: str, **kwargs: Any) -> Tuple[Callable[[], Dict], Dict]:
def get_request(
self, query: str, request_args: Dict[str, Any] = {}
) -> Tuple[Callable[[], Dict], Dict]:
"""
Get request string function.
@ -116,16 +119,22 @@ class CRFMClient(Client):
request parameters as dict.
"""
request_params = {
"model": kwargs.get("engine", self.engine),
"model": request_args.pop("engine", self.engine),
"prompt": query,
"temperature": kwargs.get("temperature", self.temperature),
"max_tokens": kwargs.get("max_tokens", self.max_tokens),
"top_k_per_token": kwargs.get("top_k_per_token", self.top_k_per_token),
"num_completions": kwargs.get("num_completions", self.num_completions),
"stop_sequences": kwargs.get("stop_sequences", self.stop_sequences),
"top_p": kwargs.get("top_p", self.top_p),
"presence_penalty": kwargs.get("presence_penalty", self.presence_penalty),
"frequency_penalty": kwargs.get(
"temperature": request_args.pop("temperature", self.temperature),
"max_tokens": request_args.pop("max_tokens", self.max_tokens),
"top_k_per_token": request_args.pop(
"top_k_per_token", self.top_k_per_token
),
"num_completions": request_args.pop(
"num_completions", self.num_completions
),
"stop_sequences": request_args.pop("stop_sequences", self.stop_sequences),
"top_p": request_args.pop("top_p", self.top_p),
"presence_penalty": request_args.pop(
"presence_penalty", self.presence_penalty
),
"frequency_penalty": request_args.pop(
"frequency_penalty", self.frequency_penalty
),
}

@ -42,7 +42,9 @@ class DummyClient(Client):
"""
return {"engine": "dummy"}
def get_request(self, query: str, **kwargs: Any) -> Tuple[Callable[[], Dict], Dict]:
def get_request(
self, query: str, request_args: Dict[str, Any] = {}
) -> Tuple[Callable[[], Dict], Dict]:
"""
Get request string function.
@ -55,10 +57,10 @@ class DummyClient(Client):
"""
request_params = {
"prompt": query,
"num_results": kwargs.get("num_results", self.num_results),
"num_results": request_args.pop("num_results", self.num_results),
}
def _run_completion() -> Dict:
return {"choices": [{"text": "hello"}] * self.num_results}
return {"choices": [{"text": "hello"}] * request_params["num_results"]}
return _run_completion, request_params

@ -51,7 +51,9 @@ class HuggingFaceClient(Client):
res = requests.post(self.host + "/params")
return res.json()
def get_request(self, query: str, **kwargs: Any) -> Tuple[Callable[[], Dict], Dict]:
def get_request(
self, query: str, request_args: Dict[str, Any] = {}
) -> Tuple[Callable[[], Dict], Dict]:
"""
Get request string function.
@ -64,15 +66,15 @@ class HuggingFaceClient(Client):
"""
request_params = {
"prompt": query,
"temperature": kwargs.get("temperature", self.temperature),
"max_tokens": kwargs.get("max_tokens", self.max_tokens),
"top_p": kwargs.get("top_p", self.top_p),
"top_k": kwargs.get("top_k", self.top_k),
"do_sample": kwargs.get("do_sample", self.do_sample),
"repetition_penalty": kwargs.get(
"temperature": request_args.pop("temperature", self.temperature),
"max_tokens": request_args.pop("max_tokens", self.max_tokens),
"top_p": request_args.pop("top_p", self.top_p),
"top_k": request_args.pop("top_k", self.top_k),
"do_sample": request_args.pop("do_sample", self.do_sample),
"repetition_penalty": request_args.pop(
"repetition_penalty", self.repetition_penalty
),
"n": kwargs.get("n", self.n),
"n": request_args.pop("n", self.n),
}
request_params.update(self.model_params)

@ -71,7 +71,9 @@ class OpenAIClient(Client):
"""
return {"model_name": "openai", "engine": self.engine}
def get_request(self, query: str, **kwargs: Any) -> Tuple[Callable[[], Dict], Dict]:
def get_request(
self, query: str, request_args: Dict[str, Any] = {}
) -> Tuple[Callable[[], Dict], Dict]:
"""
Get request string function.
@ -83,18 +85,20 @@ class OpenAIClient(Client):
request parameters as dict.
"""
request_params = {
"engine": kwargs.get("engine", self.engine),
"engine": request_args.pop("engine", self.engine),
"prompt": query,
"temperature": kwargs.get("temperature", self.temperature),
"max_tokens": kwargs.get("max_tokens", self.max_tokens),
"top_p": kwargs.get("top_p", self.top_p),
"frequency_penalty": kwargs.get(
"temperature": request_args.pop("temperature", self.temperature),
"max_tokens": request_args.pop("max_tokens", self.max_tokens),
"top_p": request_args.pop("top_p", self.top_p),
"frequency_penalty": request_args.pop(
"frequency_penalty", self.frequency_penalty
),
"logprobs": kwargs.get("logprobs", self.logprobs),
"best_of": kwargs.get("best_of", self.best_of),
"presence_penalty": kwargs.get("presence_penalty", self.presence_penalty),
"n": kwargs.get("n", self.n),
"logprobs": request_args.pop("logprobs", self.logprobs),
"best_of": request_args.pop("best_of", self.best_of),
"presence_penalty": request_args.pop(
"presence_penalty", self.presence_penalty
),
"n": request_args.pop("n", self.n),
}
def _run_completion() -> Dict:

@ -46,7 +46,9 @@ class OPTClient(Client):
"""
return {"model_name": "opt"}
def get_request(self, query: str, **kwargs: Any) -> Tuple[Callable[[], Dict], Dict]:
def get_request(
self, query: str, request_args: Dict[str, Any] = {}
) -> Tuple[Callable[[], Dict], Dict]:
"""
Get request string function.
@ -60,10 +62,10 @@ class OPTClient(Client):
request_params = {
"prompt": query,
"engine": "opt",
"temperature": kwargs.get("temperature", self.temperature),
"max_tokens": kwargs.get("max_tokens", self.max_tokens),
"top_p": kwargs.get("top_p", self.top_p),
"n": kwargs.get("n", self.n),
"temperature": request_args.pop("temperature", self.temperature),
"max_tokens": request_args.pop("max_tokens", self.max_tokens),
"top_p": request_args.pop("top_p", self.top_p),
"n": request_args.pop("n", self.n),
}
def _run_completion() -> Dict:

@ -121,7 +121,10 @@ class Manifest:
prompt = Prompt(prompt)
stop_token = stop_token if stop_token is not None else self.stop_token
prompt_str = prompt(input)
possible_request, full_kwargs = self.client.get_request(prompt_str, **kwargs)
# Must pass kwargs as a dict so the clients' "pop" calls remove used arguments
possible_request, full_kwargs = self.client.get_request(prompt_str, kwargs)
if len(kwargs) > 0:
raise ValueError(f"{list(kwargs.items())} arguments are not recognized.")
# Create cache key
cache_key = full_kwargs.copy()
# Make query model dependent
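The pop-based pattern matters here: because each client pops the arguments it recognizes out of the shared dict, any keys left in `kwargs` after `get_request` returns must be unrecognized. A standalone sketch of the idea (names are illustrative, not Manifest's API):
```python
from typing import Any, Callable, Dict, Tuple

def get_request(
    query: str, request_args: Dict[str, Any]
) -> Tuple[Callable[[], Dict], Dict]:
    # Recognized keys are popped off, so leftovers signal unknown arguments.
    request_params = {
        "prompt": query,
        "max_tokens": request_args.pop("max_tokens", 16),
    }
    return (lambda: {"choices": [{"text": "ok"}]}), request_params

kwargs = {"max_tokens": 50, "bad_input": 5}
_, params = get_request("hello", kwargs)
if len(kwargs) > 0:  # only "bad_input" remains after the pops
    raise ValueError(f"{list(kwargs.items())} arguments are not recognized.")
```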

@ -25,6 +25,12 @@ class Response:
"Response must be serialized to a dict with a "
"list of choices with text field"
)
if "logprobs" in self._response["choices"][0]:
if not isinstance(self._response["choices"][0]["logprobs"], list):
raise ValueError(
"Response must be serialized to a dict with a "
"list of choices with logprobs field"
)
self._cached = cached
self._request_params = request_params
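Under the new check, a serialized response must carry a list of choices with a `text` field, and any `logprobs` entry must itself be a list; for example (an illustrative dict, not taken from the codebase):
```python
# Illustrative dict that passes the validation above: "choices" is a
# list of dicts with a "text" field, and "logprobs", when present,
# is a list rather than a dict or scalar.
serialized = {
    "choices": [
        {"text": "hello", "logprobs": [{"token": "hello", "logprob": -0.2}]}
    ]
}
```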

@ -22,16 +22,11 @@ openai = "^0.18.1"
redis = "^4.3.1"
dill = "^0.3.5"
Flask = "^2.1.2"
#transformers = "^4.19.2"
#torch = "^1.11.0"
transformers = "^4.19.2"
torch = "^1.8"
requests = "^2.27.1"
tqdm = "^4.64.0"
types-redis = "^4.2.6"
types-requests = "^2.27.29"
types-PyYAML = "^6.0.7"
types-protobuf = "^3.19.21"
types-python-dateutil = "^2.8.16"
types-setuptools = "^57.4.17"
uuid = "^1.30"
[tool.poetry.dev-dependencies]
black = "^22.3.0"
@ -45,6 +40,12 @@ pytest = "^7.0.0"
pytest-cov = "^3.0.0"
python-dotenv = "^0.20.0"
recommonmark = "^0.7.1"
types-redis = "^4.2.6"
types-requests = "^2.27.29"
types-PyYAML = "^6.0.7"
types-protobuf = "^3.19.21"
types-python-dateutil = "^2.8.16"
types-setuptools = "^57.4.17"
[build-system]
build-backend = "poetry.core.masonry.api"

@ -3,12 +3,14 @@ Test client.
We just test the dummy client as we don't want to load a model or use OpenAI tokens.
"""
from manifest.clients.dummy import DummyClient
def test_init():
"""Test client initialization."""
client = DummyClient(connection_str=None)
assert client.num_results == 1
args = {"num_results": 3}
client = DummyClient(connection_str=None, client_args=args)
assert client.num_results == 3
@ -21,3 +23,7 @@ def test_get_request():
request_func, request_params = client.get_request("hello")
assert request_params == {"prompt": "hello", "num_results": 3}
assert request_func() == {"choices": [{"text": "hello"}] * 3}
request_func, request_params = client.get_request("hello", {"num_results": 5})
assert request_params == {"prompt": "hello", "num_results": 5}
assert request_func() == {"choices": [{"text": "hello"}] * 5}

@ -55,6 +55,12 @@ def test_run(sqlite_cache, num_results, return_response):
cache_connection=sqlite_cache,
num_results=num_results,
)
prompt = Prompt("This is a prompt")
with pytest.raises(ValueError) as exc_info:
result = manifest.run(prompt, return_response=return_response, bad_input=5)
assert str(exc_info.value) == "[('bad_input', 5)] arguments are not recognized."
prompt = Prompt("This is a prompt")
result = manifest.run(prompt, return_response=return_response)
if return_response:
