diff --git a/docs/examples/prompts/prompt_management.ipynb b/docs/examples/prompts/prompt_management.ipynb index 5d0b4731..c10ac723 100644 --- a/docs/examples/prompts/prompt_management.ipynb +++ b/docs/examples/prompts/prompt_management.ipynb @@ -71,7 +71,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 11, "id": "094229f4", "metadata": {}, "outputs": [], @@ -151,6 +151,59 @@ "multiple_input_prompt.format(adjective=\"funny\", content=\"chickens\")" ] }, + { + "cell_type": "markdown", + "id": "b2dd6154", + "metadata": {}, + "source": [ + "## Alternative formats\n", + "\n", + "This section shows how to use alternative formats besides \"f-string\" to format prompts." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "53b41b6a", + "metadata": {}, + "outputs": [], + "source": [ + "# Jinja2\n", + "template = \"\"\"\n", + "{% for item in items %}\n", + "Question: {{ item.question }}\n", + "Answer: {{ item.answer }}\n", + "{% endfor %}\n", + "\"\"\"\n", + "items=[{\"question\": \"foo\", \"answer\": \"bar\"},{\"question\": \"1\", \"answer\": \"2\"}]\n", + "jinja2_prompt = PromptTemplate(\n", + " input_variables=[\"items\"], \n", + " template=template,\n", + " template_format=\"jinja2\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "ba8aabd3", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'\\n\\nQuestion: foo\\nAnswer: bar\\n\\nQuestion: 1\\nAnswer: 2\\n'" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "jinja2_prompt.format(items=items)" + ] + }, { "cell_type": "markdown", "id": "1492b49d", @@ -602,7 +655,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.6" + "version": "3.10.8" } }, "nbformat": 4, diff --git a/langchain/prompts/base.py b/langchain/prompts/base.py index b5ba37eb..c7b708a3 100644 --- a/langchain/prompts/base.py +++ b/langchain/prompts/base.py @@ -2,15 +2,30 @@ import json from abc import ABC, abstractmethod from pathlib import Path -from typing import Any, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union import yaml from pydantic import BaseModel, Extra, root_validator from langchain.formatting import formatter -DEFAULT_FORMATTER_MAPPING = { + +def jinja2_formatter(template: str, **kwargs: Any) -> str: + """Format a template using jinja2.""" + try: + from jinja2 import Template + except ImportError: + raise ValueError( + "jinja2 not installed, which is needed to use the jinja2_formatter. " + "Please install it with `pip install jinja2`." + ) + + return Template(template).render(**kwargs) + + +DEFAULT_FORMATTER_MAPPING: Dict[str, Callable] = { "f-string": formatter.format, + "jinja2": jinja2_formatter, } diff --git a/langchain/prompts/few_shot.py b/langchain/prompts/few_shot.py index 98dffd65..1baab6fa 100644 --- a/langchain/prompts/few_shot.py +++ b/langchain/prompts/few_shot.py @@ -39,7 +39,7 @@ class FewShotPromptTemplate(BasePromptTemplate, BaseModel): """A prompt template string to put before the examples.""" template_format: str = "f-string" - """The format of the prompt template. Options are: 'f-string'.""" + """The format of the prompt template. Options are: 'f-string', 'jinja2'.""" @root_validator(pre=True) def check_examples_and_selector(cls, values: Dict) -> Dict: diff --git a/langchain/prompts/prompt.py b/langchain/prompts/prompt.py index 4a24c2de..cccbbd2c 100644 --- a/langchain/prompts/prompt.py +++ b/langchain/prompts/prompt.py @@ -29,7 +29,7 @@ class PromptTemplate(BasePromptTemplate, BaseModel): """The prompt template.""" template_format: str = "f-string" - """The format of the prompt template. Options are: 'f-string'.""" + """The format of the prompt template. Options are: 'f-string', 'jinja2'.""" class Config: """Configuration for this pydantic object.""" diff --git a/poetry.lock b/poetry.lock index 9dcbd09b..ded41a4a 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2338,13 +2338,13 @@ docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker testing = ["flake8 (<5)", "func-timeout", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)"] [extras] -all = ["manifest-ml", "elasticsearch", "faiss-cpu", "transformers", "spacy", "nltk", "wikipedia", "beautifulsoup4", "tiktoken", "torch"] +all = ["manifest-ml", "elasticsearch", "faiss-cpu", "transformers", "spacy", "nltk", "wikipedia", "beautifulsoup4", "tiktoken", "torch", "jinja2"] llms = ["manifest-ml", "torch", "transformers"] [metadata] lock-version = "1.1" python-versions = ">=3.8.1,<4.0" -content-hash = "d9ed2c5e1b2c51d7f8a9f74c858ab5058db14b7d6ca542777d7ecd07ccef5ee8" +content-hash = "4099d88111a7f46857283d91c66fdfa7ed567461e8023509f2da86736a7ae24a" [metadata.files] anyio = [ diff --git a/pyproject.toml b/pyproject.toml index 4d97caa7..a20b57c2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,8 @@ spacy = {version = "^3", optional = true} nltk = {version = "^3", optional = true} transformers = {version = "^4", optional = true} beautifulsoup4 = {version = "^4", optional = true} -torch = {version = "^1.13.1", optional = true} +torch = {version = "^1", optional = true} +jinja2 = {version = "^3", optional = true} tiktoken = {version = "^0", optional = true, python="^3.9"} [tool.poetry.group.test.dependencies] @@ -51,7 +52,7 @@ playwright = "^1.28.0" [tool.poetry.extras] llms = ["cohere", "openai", "nlpcloud", "huggingface_hub", "manifest-ml", "torch", "transformers"] -all = ["cohere", "openai", "nlpcloud", "huggingface_hub", "manifest-ml", "elasticsearch", "google-search-results", "faiss-cpu", "sentence_transformers", "transformers", "spacy", "nltk", "wikipedia", "beautifulsoup4", "tiktoken", "torch"] +all = ["cohere", "openai", "nlpcloud", "huggingface_hub", "manifest-ml", "elasticsearch", "google-search-results", "faiss-cpu", "sentence_transformers", "transformers", "spacy", "nltk", "wikipedia", "beautifulsoup4", "tiktoken", "torch", "jinja2"] [tool.isort] profile = "black"