Save Prompts (#194)

harrison/logging_to_file
Akash Samant 2 years ago committed by GitHub
parent b90e25f786
commit ae72cf84b8

@@ -1,7 +1,10 @@
 """BasePrompt schema definition."""
+import json
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List
+from pathlib import Path
+from typing import Any, Dict, List, Union
 
+import yaml
 from pydantic import BaseModel, root_validator
 
 from langchain.formatting import formatter
@@ -61,3 +64,39 @@ class BasePromptTemplate(BaseModel, ABC):
             prompt.format(variable1="foo")
     """
 
+    def _prompt_dict(self) -> Dict:
+        """Return a dictionary of the prompt."""
+        return self.dict()
+
+    def save(self, file_path: Union[Path, str]) -> None:
+        """Save the prompt.
+
+        Args:
+            file_path: Path to the file to save the prompt to.
+
+        Example:
+
+        .. code-block:: python
+
+            prompt.save(file_path="path/prompt.yaml")
+        """
+        # Convert file path to Path object.
+        if isinstance(file_path, str):
+            save_path = Path(file_path)
+        else:
+            save_path = file_path
+
+        directory_path = save_path.parent
+        directory_path.mkdir(parents=True, exist_ok=True)
+
+        # Fetch dictionary to save.
+        prompt_dict = self._prompt_dict()
+
+        if save_path.suffix == ".json":
+            with open(file_path, "w") as f:
+                f.write(json.dumps(prompt_dict, indent=4))
+        elif save_path.suffix == ".yaml":
+            with open(file_path, "w") as f:
+                yaml.dump(prompt_dict, f, default_flow_style=False)
+        else:
+            raise ValueError(f"{save_path} must be json or yaml")
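
For reference, a minimal usage sketch of the new save method together with the load_prompt loader exercised in the tests below (import locations are assumed here, not shown in this diff):

    # Sketch only: import paths are assumed and may differ in the actual package layout.
    from pathlib import Path
    from langchain.prompts.prompt import PromptTemplate
    from langchain.prompts.loading import load_prompt

    prompt = PromptTemplate(
        input_variables=["product"],
        template="What is a good name for a company that makes {product}?",
    )

    # The file suffix picks the format: .json -> json.dumps(indent=4), .yaml -> yaml.dump.
    path = Path("company_prompt.yaml")
    prompt.save(file_path=path)

    # Loading the file back yields an object that compares equal to the original.
    reloaded = load_prompt(path)
    assert reloaded == prompt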

@@ -108,3 +108,12 @@ class FewShotPromptTemplate(BasePromptTemplate, BaseModel):
         template = self.example_separator.join([piece for piece in pieces if piece])
         # Format the template with the input variables.
         return DEFAULT_FORMATTER_MAPPING[self.template_format](template, **kwargs)
+
+    def _prompt_dict(self) -> Dict:
+        """Return a dictionary of the prompt."""
+        if self.example_selector:
+            raise ValueError("Saving an example selector is not currently supported")
+
+        prompt_dict = self.dict()
+        prompt_dict["_type"] = "few_shot"
+        return prompt_dict
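
The "_type" key acts as a discriminator so a loader can tell a serialized FewShotPromptTemplate apart from a plain prompt. A small sketch of what _prompt_dict returns, using the few_shot_prompt constructed in the round-trip test below (field names follow pydantic's .dict() output):

    # Sketch only: inspecting the serialized form of a FewShotPromptTemplate.
    prompt_dict = few_shot_prompt._prompt_dict()
    assert prompt_dict["_type"] == "few_shot"   # discriminator added on top of .dict()
    assert "examples" in prompt_dict            # hard-coded examples are serialized
    # Prompts built with an example_selector raise ValueError instead of serializing.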

@@ -43,6 +43,34 @@ def test_loading_from_JSON() -> None:
     assert prompt == expected_prompt
 
 
+def test_saving_loading_round_trip(tmp_path: Path) -> None:
+    """Test equality when saving and loading a prompt."""
+    simple_prompt = PromptTemplate(
+        input_variables=["adjective", "content"],
+        template="Tell me a {adjective} joke about {content}.",
+    )
+    simple_prompt.save(file_path=tmp_path / "prompt.yaml")
+    loaded_prompt = load_prompt(tmp_path / "prompt.yaml")
+    assert loaded_prompt == simple_prompt
+
+    few_shot_prompt = FewShotPromptTemplate(
+        input_variables=["adjective"],
+        prefix="Write antonyms for the following words.",
+        example_prompt=PromptTemplate(
+            input_variables=["input", "output"],
+            template="Input: {input}\nOutput: {output}",
+        ),
+        examples=[
+            {"input": "happy", "output": "sad"},
+            {"input": "tall", "output": "short"},
+        ],
+        suffix="Input: {adjective}\nOutput:",
+    )
+    few_shot_prompt.save(file_path=tmp_path / "few_shot.yaml")
+    loaded_prompt = load_prompt(tmp_path / "few_shot.yaml")
+    assert loaded_prompt == few_shot_prompt
+
+
 def test_loading_with_template_as_file() -> None:
     """Test loading when the template is a file."""
     with change_directory():
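
Note that the new round-trip test only exercises the .yaml branch of save(); the .json branch writes the same dictionary via json.dumps. A JSON round trip through load_prompt (assumed to work symmetrically; not verified in this diff) would look like:

    # Sketch only: JSON round trip, mirroring the YAML assertions in the test above.
    simple_prompt.save(file_path=tmp_path / "prompt.json")   # written via json.dumps(..., indent=4)
    assert load_prompt(tmp_path / "prompt.json") == simple_prompt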
