Merge pull request #1 from HazyResearch/prompt

[prompt] add serialization
laurel/helm
Laurel Orr 2 years ago committed by GitHub
commit 020bbb37b4

@ -23,7 +23,7 @@ jobs:
# Label used to access the service container
redis:
# Docker Hub image
image: redislabs/redis
image: redis
# Set health checks to wait until redis has started
options: >-
--health-cmd "redis-cli ping"
@ -62,4 +62,4 @@ jobs:
make check
- name: Test with pytest
run: |
poetry run pytest tests
poetry run pytest tests

@ -18,4 +18,11 @@ or
```
pip install poetry
make dev
```
```
# Development
Before submitting a PR, run
```
export REDIS_PORT="6379" # or whatever PORT local redis is running for those tests
make test
```

@ -2,7 +2,7 @@
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Optional, Tuple
from manifest.clients import Response
from manifest.clients.response import Response
class Client(ABC):

@ -23,8 +23,8 @@ class Response:
if len(self.response["choices"]) > 0:
if "text" not in self.response["choices"][0]:
raise ValueError(
"Response must be serialized to a dict with a ",
"list of choices with text field",
"Response must be serialized to a dict with a "
"list of choices with text field"
)
def __getitem__(self, key: str) -> str:

@ -5,11 +5,11 @@ from typing import Any, Iterable, List, Optional, Union
logging.getLogger("openai").setLevel(logging.WARNING)
logger = logging.getLogger(__name__)
from manifest import Prompt
from manifest.caches.redis import RedisCache
from manifest.caches.sqlite import SQLiteCache
from manifest.clients.dummy import DummyClient
from manifest.clients.openai import OpenAIClient
from manifest.prompt import Prompt
CLIENT_CONSTRUCTORS = {
"openai": OpenAIClient,

@ -4,6 +4,8 @@ import inspect
import logging
from typing import Any, Callable, List, Optional, Union
import dill
logging.getLogger("openai").setLevel(logging.WARNING)
logger = logging.getLogger(__name__)
@ -52,10 +54,9 @@ class Prompt:
Return the prompt as str.
Returns:
json object.
prompt as str.
"""
# TODO: implement
pass
return dill.dumps(self.prompt_func)
@classmethod
def deserialize(cls, obj: str) -> "Prompt":
@ -68,5 +69,4 @@ class Prompt:
Return:
prompt.
"""
# TODO: implement
pass
return Prompt(dill.loads(obj))

17
poetry.lock generated

@ -154,6 +154,17 @@ wrapt = ">=1.10,<2"
[package.extras]
dev = ["tox", "bump2version (<1)", "sphinx (<2)", "importlib-metadata (<3)", "importlib-resources (<4)", "configparser (<5)", "sphinxcontrib-websupport (<2)", "zipp (<2)", "PyTest (<5)", "PyTest-Cov (<2.6)", "pytest", "pytest-cov"]
[[package]]
name = "dill"
version = "0.3.5.1"
description = "serialize all of python"
category = "main"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, !=3.6.*"
[package.extras]
graph = ["objgraph (>=1.7.2)"]
[[package]]
name = "distlib"
version = "0.3.4"
@ -864,7 +875,7 @@ python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7"
[metadata]
lock-version = "1.1"
python-versions = "^3.8"
content-hash = "6423e5907172dc326eb2deaed10a5dcce8c721b043e5e205d7116be715b2d867"
content-hash = "cfa01a852b186ab6ddd1b610dba12c7d4ac6f10c7bdb10ca7bac6a0d25568d00"
[metadata.files]
alabaster = [
@ -983,6 +994,10 @@ deprecated = [
{file = "Deprecated-1.2.13-py2.py3-none-any.whl", hash = "sha256:64756e3e14c8c5eea9795d93c524551432a0be75629f8f29e67ab8caf076c76d"},
{file = "Deprecated-1.2.13.tar.gz", hash = "sha256:43ac5335da90c31c24ba028af536a91d41d53f9e6901ddb021bcc572ce44e38d"},
]
dill = [
{file = "dill-0.3.5.1-py2.py3-none-any.whl", hash = "sha256:33501d03270bbe410c72639b350e941882a8b0fd55357580fbc873fba0c59302"},
{file = "dill-0.3.5.1.tar.gz", hash = "sha256:d75e41f3eff1eee599d738e76ba8f4ad98ea229db8b085318aa2b3333a208c86"},
]
distlib = [
{file = "distlib-0.3.4-py2.py3-none-any.whl", hash = "sha256:6564fe0a8f51e734df6333d08b8b94d4ea8ee6b99b5ed50613f731fd4089f34b"},
{file = "distlib-0.3.4.zip", hash = "sha256:e4b58818180336dc9c529bfb9a0b58728ffc09ad92027a3f30b7cd91e3458579"},

@ -20,6 +20,7 @@ python = "^3.8"
sqlitedict = "^2.0.0"
openai = "^0.18.1"
redis = "^4.3.1"
dill = "^0.3.5"
[tool.poetry.dev-dependencies]
black = "^22.3.0"
@ -49,6 +50,7 @@ module = [
"numpy",
"tqdm",
"sqlitedict",
"dill",
]
[tool.isort]
@ -66,4 +68,4 @@ addopts = "-v -rsXx"
# The following options are useful for local debugging
# addopts = "-v -rsXx -s -x --pdb"
# log_cli_level = "DEBUG"
# log_cli = true
# log_cli = true

@ -20,8 +20,8 @@ def test_init():
with pytest.raises(ValueError) as exc_info:
response = Response({"choices": [{"blah": "hello"}]})
assert str(exc_info.value) == (
"Response must be serialized to a dict ",
"with a list of choices with text field",
"Response must be serialized to a dict "
"with a list of choices with text field"
)
response = Response({"choices": [{"text": "hello"}]})

@ -39,7 +39,6 @@ def test_init():
assert str(exc_info.value) == "Prompts must have zero or one input."
@pytest.mark.skip(reason="Not implemented")
def test_serialize():
"""Test prompt serialization."""
str_prompt = "This is a test prompt"
@ -47,8 +46,10 @@ def test_serialize():
# String prompt
prompt = Prompt(str_prompt)
assert Prompt.deserialize(prompt.serialize()) == prompt
assert Prompt.deserialize(prompt.serialize()).prompt_func() == prompt.prompt_func()
# Function single inputs
prompt = Prompt(func_single_prompt)
assert Prompt.deserialize(prompt.serialize()) == prompt
assert Prompt.deserialize(prompt.serialize()).prompt_func(1) == prompt.prompt_func(
1
)

Loading…
Cancel
Save