Harrison/jinja formatter (#385)

Co-authored-by: Benjamin <BenderV@users.noreply.github.com>
This commit is contained in:
Harrison Chase 2022-12-19 16:40:39 -05:00 committed by GitHub
parent fc66a32c6f
commit ffed5e0056
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 79 additions and 10 deletions

View File

@ -71,7 +71,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 1, "execution_count": 11,
"id": "094229f4", "id": "094229f4",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -151,6 +151,59 @@
"multiple_input_prompt.format(adjective=\"funny\", content=\"chickens\")" "multiple_input_prompt.format(adjective=\"funny\", content=\"chickens\")"
] ]
}, },
{
"cell_type": "markdown",
"id": "b2dd6154",
"metadata": {},
"source": [
"## Alternative formats\n",
"\n",
"This section shows how to use alternative formats besides \"f-string\" to format prompts."
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "53b41b6a",
"metadata": {},
"outputs": [],
"source": [
"# Jinja2\n",
"template = \"\"\"\n",
"{% for item in items %}\n",
"Question: {{ item.question }}\n",
"Answer: {{ item.answer }}\n",
"{% endfor %}\n",
"\"\"\"\n",
"items=[{\"question\": \"foo\", \"answer\": \"bar\"},{\"question\": \"1\", \"answer\": \"2\"}]\n",
"jinja2_prompt = PromptTemplate(\n",
" input_variables=[\"items\"], \n",
" template=template,\n",
" template_format=\"jinja2\"\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "ba8aabd3",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'\\n\\nQuestion: foo\\nAnswer: bar\\n\\nQuestion: 1\\nAnswer: 2\\n'"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"jinja2_prompt.format(items=items)"
]
},
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "1492b49d", "id": "1492b49d",
@ -602,7 +655,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.7.6" "version": "3.10.8"
} }
}, },
"nbformat": 4, "nbformat": 4,

View File

@ -2,15 +2,30 @@
import json import json
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional, Union from typing import Any, Callable, Dict, List, Optional, Union
import yaml import yaml
from pydantic import BaseModel, Extra, root_validator from pydantic import BaseModel, Extra, root_validator
from langchain.formatting import formatter from langchain.formatting import formatter
DEFAULT_FORMATTER_MAPPING = {
def jinja2_formatter(template: str, **kwargs: Any) -> str:
"""Format a template using jinja2."""
try:
from jinja2 import Template
except ImportError:
raise ValueError(
"jinja2 not installed, which is needed to use the jinja2_formatter. "
"Please install it with `pip install jinja2`."
)
return Template(template).render(**kwargs)
DEFAULT_FORMATTER_MAPPING: Dict[str, Callable] = {
"f-string": formatter.format, "f-string": formatter.format,
"jinja2": jinja2_formatter,
} }

View File

@ -39,7 +39,7 @@ class FewShotPromptTemplate(BasePromptTemplate, BaseModel):
"""A prompt template string to put before the examples.""" """A prompt template string to put before the examples."""
template_format: str = "f-string" template_format: str = "f-string"
"""The format of the prompt template. Options are: 'f-string'.""" """The format of the prompt template. Options are: 'f-string', 'jinja2'."""
@root_validator(pre=True) @root_validator(pre=True)
def check_examples_and_selector(cls, values: Dict) -> Dict: def check_examples_and_selector(cls, values: Dict) -> Dict:

View File

@ -29,7 +29,7 @@ class PromptTemplate(BasePromptTemplate, BaseModel):
"""The prompt template.""" """The prompt template."""
template_format: str = "f-string" template_format: str = "f-string"
"""The format of the prompt template. Options are: 'f-string'.""" """The format of the prompt template. Options are: 'f-string', 'jinja2'."""
class Config: class Config:
"""Configuration for this pydantic object.""" """Configuration for this pydantic object."""

4
poetry.lock generated
View File

@ -2338,13 +2338,13 @@ docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker
testing = ["flake8 (<5)", "func-timeout", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)"] testing = ["flake8 (<5)", "func-timeout", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)"]
[extras] [extras]
all = ["manifest-ml", "elasticsearch", "faiss-cpu", "transformers", "spacy", "nltk", "wikipedia", "beautifulsoup4", "tiktoken", "torch"] all = ["manifest-ml", "elasticsearch", "faiss-cpu", "transformers", "spacy", "nltk", "wikipedia", "beautifulsoup4", "tiktoken", "torch", "jinja2"]
llms = ["manifest-ml", "torch", "transformers"] llms = ["manifest-ml", "torch", "transformers"]
[metadata] [metadata]
lock-version = "1.1" lock-version = "1.1"
python-versions = ">=3.8.1,<4.0" python-versions = ">=3.8.1,<4.0"
content-hash = "d9ed2c5e1b2c51d7f8a9f74c858ab5058db14b7d6ca542777d7ecd07ccef5ee8" content-hash = "4099d88111a7f46857283d91c66fdfa7ed567461e8023509f2da86736a7ae24a"
[metadata.files] [metadata.files]
anyio = [ anyio = [

View File

@ -22,7 +22,8 @@ spacy = {version = "^3", optional = true}
nltk = {version = "^3", optional = true} nltk = {version = "^3", optional = true}
transformers = {version = "^4", optional = true} transformers = {version = "^4", optional = true}
beautifulsoup4 = {version = "^4", optional = true} beautifulsoup4 = {version = "^4", optional = true}
torch = {version = "^1.13.1", optional = true} torch = {version = "^1", optional = true}
jinja2 = {version = "^3", optional = true}
tiktoken = {version = "^0", optional = true, python="^3.9"} tiktoken = {version = "^0", optional = true, python="^3.9"}
[tool.poetry.group.test.dependencies] [tool.poetry.group.test.dependencies]
@ -51,7 +52,7 @@ playwright = "^1.28.0"
[tool.poetry.extras] [tool.poetry.extras]
llms = ["cohere", "openai", "nlpcloud", "huggingface_hub", "manifest-ml", "torch", "transformers"] llms = ["cohere", "openai", "nlpcloud", "huggingface_hub", "manifest-ml", "torch", "transformers"]
all = ["cohere", "openai", "nlpcloud", "huggingface_hub", "manifest-ml", "elasticsearch", "google-search-results", "faiss-cpu", "sentence_transformers", "transformers", "spacy", "nltk", "wikipedia", "beautifulsoup4", "tiktoken", "torch"] all = ["cohere", "openai", "nlpcloud", "huggingface_hub", "manifest-ml", "elasticsearch", "google-search-results", "faiss-cpu", "sentence_transformers", "transformers", "spacy", "nltk", "wikipedia", "beautifulsoup4", "tiktoken", "torch", "jinja2"]
[tool.isort] [tool.isort]
profile = "black" profile = "black"