diff --git a/libs/langchain/langchain/prompts/base.py b/libs/langchain/langchain/prompts/base.py index f2cbd6dad4..d71a264988 100644 --- a/libs/langchain/langchain/prompts/base.py +++ b/libs/langchain/langchain/prompts/base.py @@ -3,6 +3,7 @@ from __future__ import annotations import warnings from abc import ABC +from string import Formatter from typing import Any, Callable, Dict, List, Literal, Set from langchain.schema.messages import BaseMessage, HumanMessage @@ -99,6 +100,20 @@ def check_valid_template( ) +def get_template_variables(template: str, template_format: str) -> List[str]: + if template_format == "jinja2": + # Get the variables for the template + input_variables = _get_jinja2_variables_from_template(template) + elif template_format == "f-string": + input_variables = { + v for _, v, _, _ in Formatter().parse(template) if v is not None + } + else: + raise ValueError(f"Unsupported template format: {template_format}") + + return sorted(input_variables) + + class StringPromptValue(PromptValue): """String prompt value.""" diff --git a/libs/langchain/langchain/prompts/chat.py b/libs/langchain/langchain/prompts/chat.py index 8c301c109a..e6f99b23a6 100644 --- a/libs/langchain/langchain/prompts/chat.py +++ b/libs/langchain/langchain/prompts/chat.py @@ -382,6 +382,8 @@ class ChatPromptTemplate(BaseChatPromptTemplate): """List of input variables in template messages. Used for validation.""" messages: List[MessageLike] """List of messages consisting of either message prompt templates or messages.""" + validate_template: bool = False + """Whether or not to try validating the template.""" def __add__(self, other: Any) -> ChatPromptTemplate: """Combine two prompt templates. @@ -432,7 +434,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate): input_types[message.variable_name] = List[AnyMessage] if "partial_variables" in values: input_vars = input_vars - set(values["partial_variables"]) - if "input_variables" in values: + if "input_variables" in values and values.get("validate_template"): if input_vars != set(values["input_variables"]): raise ValueError( "Got mismatched input_variables. " diff --git a/libs/langchain/langchain/prompts/few_shot.py b/libs/langchain/langchain/prompts/few_shot.py index e8fa1b2447..4016336d70 100644 --- a/libs/langchain/langchain/prompts/few_shot.py +++ b/libs/langchain/langchain/prompts/few_shot.py @@ -2,12 +2,13 @@ from __future__ import annotations from pathlib import Path -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Literal, Optional, Union from langchain.prompts.base import ( DEFAULT_FORMATTER_MAPPING, StringPromptTemplate, check_valid_template, + get_template_variables, ) from langchain.prompts.chat import BaseChatPromptTemplate, BaseMessagePromptTemplate from langchain.prompts.example_selector.base import BaseExampleSelector @@ -77,7 +78,7 @@ class FewShotPromptTemplate(_FewShotPromptTemplateMixin, StringPromptTemplate): """Return whether or not the class is serializable.""" return False - validate_template: bool = True + validate_template: bool = False """Whether or not to try validating the template.""" input_variables: List[str] @@ -95,7 +96,7 @@ class FewShotPromptTemplate(_FewShotPromptTemplateMixin, StringPromptTemplate): prefix: str = "" """A prompt template string to put before the examples.""" - template_format: str = "f-string" + template_format: Union[Literal["f-string"], Literal["jinja2"]] = "f-string" """The format of the prompt template. Options are: 'f-string', 'jinja2'.""" @root_validator() @@ -107,6 +108,14 @@ class FewShotPromptTemplate(_FewShotPromptTemplateMixin, StringPromptTemplate): values["template_format"], values["input_variables"] + list(values["partial_variables"]), ) + elif values.get("template_format"): + values["input_variables"] = [ + var + for var in get_template_variables( + values["prefix"] + values["suffix"], values["template_format"] + ) + if var not in values["partial_variables"] + ] return values class Config: diff --git a/libs/langchain/langchain/prompts/few_shot_with_templates.py b/libs/langchain/langchain/prompts/few_shot_with_templates.py index 1e34a0f5a8..bee5fe71e0 100644 --- a/libs/langchain/langchain/prompts/few_shot_with_templates.py +++ b/libs/langchain/langchain/prompts/few_shot_with_templates.py @@ -37,7 +37,7 @@ class FewShotPromptWithTemplates(StringPromptTemplate): template_format: str = "f-string" """The format of the prompt template. Options are: 'f-string', 'jinja2'.""" - validate_template: bool = True + validate_template: bool = False """Whether or not to try validating the template.""" @root_validator(pre=True) @@ -72,6 +72,12 @@ class FewShotPromptWithTemplates(StringPromptTemplate): f"Got input_variables={input_variables}, but based on " f"prefix/suffix expected {expected_input_variables}" ) + else: + values["input_variables"] = sorted( + set(values["suffix"].input_variables) + | set(values["prefix"].input_variables if values["prefix"] else []) + - set(values["partial_variables"]) + ) return values class Config: diff --git a/libs/langchain/langchain/prompts/prompt.py b/libs/langchain/langchain/prompts/prompt.py index 65583af824..27ee992f8e 100644 --- a/libs/langchain/langchain/prompts/prompt.py +++ b/libs/langchain/langchain/prompts/prompt.py @@ -2,14 +2,13 @@ from __future__ import annotations from pathlib import Path -from string import Formatter -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Literal, Optional, Union from langchain.prompts.base import ( DEFAULT_FORMATTER_MAPPING, StringPromptTemplate, - _get_jinja2_variables_from_template, check_valid_template, + get_template_variables, ) from langchain.pydantic_v1 import root_validator @@ -53,10 +52,10 @@ class PromptTemplate(StringPromptTemplate): template: str """The prompt template.""" - template_format: str = "f-string" + template_format: Union[Literal["f-string"], Literal["jinja2"]] = "f-string" """The format of the prompt template. Options are: 'f-string', 'jinja2'.""" - validate_template: bool = True + validate_template: bool = False """Whether or not to try validating the template.""" def __add__(self, other: Any) -> PromptTemplate: @@ -127,6 +126,14 @@ class PromptTemplate(StringPromptTemplate): check_valid_template( values["template"], values["template_format"], all_inputs ) + elif values.get("template_format"): + values["input_variables"] = [ + var + for var in get_template_variables( + values["template"], values["template_format"] + ) + if var not in values["partial_variables"] + ] return values @classmethod @@ -202,25 +209,17 @@ class PromptTemplate(StringPromptTemplate): Returns: The prompt template loaded from the template. """ - if template_format == "jinja2": - # Get the variables for the template - input_variables = _get_jinja2_variables_from_template(template) - elif template_format == "f-string": - input_variables = { - v for _, v, _, _ in Formatter().parse(template) if v is not None - } - else: - raise ValueError(f"Unsupported template format: {template_format}") + input_variables = get_template_variables(template, template_format) _partial_variables = partial_variables or {} if _partial_variables: - input_variables = { + input_variables = [ var for var in input_variables if var not in _partial_variables - } + ] return cls( - input_variables=sorted(input_variables), + input_variables=input_variables, template=template, template_format=template_format, partial_variables=_partial_variables, diff --git a/libs/langchain/langchain/tools/sql_database/tool.py b/libs/langchain/langchain/tools/sql_database/tool.py index 99289e4a97..5dfe8f680f 100644 --- a/libs/langchain/langchain/tools/sql_database/tool.py +++ b/libs/langchain/langchain/tools/sql_database/tool.py @@ -97,11 +97,11 @@ class QuerySQLCheckerTool(BaseSQLDatabaseTool, BaseTool): values["llm_chain"] = LLMChain( llm=values.get("llm"), prompt=PromptTemplate( - template=QUERY_CHECKER, input_variables=["query", "dialect"] + template=QUERY_CHECKER, input_variables=["dialect", "query"] ), ) - if values["llm_chain"].prompt.input_variables != ["query", "dialect"]: + if values["llm_chain"].prompt.input_variables != ["dialect", "query"]: raise ValueError( "LLM chain for QueryCheckerTool must have input variables ['query', 'dialect']" ) diff --git a/libs/langchain/tests/unit_tests/examples/prompt_with_output_parser.json b/libs/langchain/tests/unit_tests/examples/prompt_with_output_parser.json index 0f313b4507..fad5ba320d 100644 --- a/libs/langchain/tests/unit_tests/examples/prompt_with_output_parser.json +++ b/libs/langchain/tests/unit_tests/examples/prompt_with_output_parser.json @@ -15,6 +15,5 @@ "partial_variables": {}, "template": "Given the following question and student answer, provide a correct answer and score the student answer.\nQuestion: {question}\nStudent Answer: {student_answer}\nCorrect Answer:", "template_format": "f-string", - "validate_template": true, "_type": "prompt" -} \ No newline at end of file +} diff --git a/libs/langchain/tests/unit_tests/prompts/test_chat.py b/libs/langchain/tests/unit_tests/prompts/test_chat.py index 7b089c687c..9ad0ba0258 100644 --- a/libs/langchain/tests/unit_tests/prompts/test_chat.py +++ b/libs/langchain/tests/unit_tests/prompts/test_chat.py @@ -186,13 +186,24 @@ def test_chat_prompt_template_with_messages() -> None: def test_chat_invalid_input_variables_extra() -> None: messages = [HumanMessage(content="foo")] with pytest.raises(ValueError): - ChatPromptTemplate(messages=messages, input_variables=["foo"]) + ChatPromptTemplate( + messages=messages, input_variables=["foo"], validate_template=True + ) + assert ( + ChatPromptTemplate(messages=messages, input_variables=["foo"]).input_variables + == [] + ) def test_chat_invalid_input_variables_missing() -> None: messages = [HumanMessagePromptTemplate.from_template("{foo}")] with pytest.raises(ValueError): - ChatPromptTemplate(messages=messages, input_variables=[]) + ChatPromptTemplate( + messages=messages, input_variables=[], validate_template=True + ) + assert ChatPromptTemplate( + messages=messages, input_variables=[] + ).input_variables == ["foo"] def test_infer_variables() -> None: diff --git a/libs/langchain/tests/unit_tests/prompts/test_few_shot.py b/libs/langchain/tests/unit_tests/prompts/test_few_shot.py index f1e5a44d9b..69e9f487b2 100644 --- a/libs/langchain/tests/unit_tests/prompts/test_few_shot.py +++ b/libs/langchain/tests/unit_tests/prompts/test_few_shot.py @@ -67,7 +67,14 @@ def test_prompt_missing_input_variables() -> None: suffix=template, examples=[], example_prompt=EXAMPLE_PROMPT, + validate_template=True, ) + assert FewShotPromptTemplate( + input_variables=[], + suffix=template, + examples=[], + example_prompt=EXAMPLE_PROMPT, + ).input_variables == ["foo"] # Test when missing in prefix template = "This is a {foo} test." @@ -78,7 +85,15 @@ def test_prompt_missing_input_variables() -> None: examples=[], prefix=template, example_prompt=EXAMPLE_PROMPT, + validate_template=True, ) + assert FewShotPromptTemplate( + input_variables=[], + suffix="foo", + examples=[], + prefix=template, + example_prompt=EXAMPLE_PROMPT, + ).input_variables == ["foo"] def test_prompt_extra_input_variables() -> None: @@ -91,7 +106,14 @@ def test_prompt_extra_input_variables() -> None: suffix=template, examples=[], example_prompt=EXAMPLE_PROMPT, + validate_template=True, ) + assert FewShotPromptTemplate( + input_variables=input_variables, + suffix=template, + examples=[], + example_prompt=EXAMPLE_PROMPT, + ).input_variables == ["foo"] def test_few_shot_functionality() -> None: @@ -248,7 +270,15 @@ def test_prompt_jinja2_missing_input_variables( examples=example_jinja2_prompt[1], example_prompt=example_jinja2_prompt[0], template_format="jinja2", + validate_template=True, ) + assert FewShotPromptTemplate( + input_variables=[], + suffix=suffix, + examples=example_jinja2_prompt[1], + example_prompt=example_jinja2_prompt[0], + template_format="jinja2", + ).input_variables == ["bar"] # Test when missing in prefix with pytest.warns(UserWarning): @@ -259,7 +289,16 @@ def test_prompt_jinja2_missing_input_variables( examples=example_jinja2_prompt[1], example_prompt=example_jinja2_prompt[0], template_format="jinja2", + validate_template=True, ) + assert FewShotPromptTemplate( + input_variables=["bar"], + suffix=suffix, + prefix=prefix, + examples=example_jinja2_prompt[1], + example_prompt=example_jinja2_prompt[0], + template_format="jinja2", + ).input_variables == ["bar", "foo"] @pytest.mark.requires("jinja2") @@ -277,7 +316,16 @@ def test_prompt_jinja2_extra_input_variables( examples=example_jinja2_prompt[1], example_prompt=example_jinja2_prompt[0], template_format="jinja2", + validate_template=True, ) + assert FewShotPromptTemplate( + input_variables=["bar", "foo", "extra", "thing"], + suffix=suffix, + prefix=prefix, + examples=example_jinja2_prompt[1], + example_prompt=example_jinja2_prompt[0], + template_format="jinja2", + ).input_variables == ["bar", "foo"] def test_few_shot_chat_message_prompt_template() -> None: diff --git a/libs/langchain/tests/unit_tests/prompts/test_few_shot_with_templates.py b/libs/langchain/tests/unit_tests/prompts/test_few_shot_with_templates.py index c5c10d743e..bf91eaaeb0 100644 --- a/libs/langchain/tests/unit_tests/prompts/test_few_shot_with_templates.py +++ b/libs/langchain/tests/unit_tests/prompts/test_few_shot_with_templates.py @@ -1,5 +1,7 @@ """Test few shot prompt template.""" +import pytest + from langchain.prompts.few_shot_with_templates import FewShotPromptWithTemplates from langchain.prompts.prompt import PromptTemplate @@ -38,3 +40,37 @@ def test_prompttemplate_prefix_suffix() -> None: "Now you try to talk about party." ) assert output == expected_output + + +def test_prompttemplate_validation() -> None: + """Test that few shot works when prefix and suffix are PromptTemplates.""" + prefix = PromptTemplate( + input_variables=["content"], template="This is a test about {content}." + ) + suffix = PromptTemplate( + input_variables=["new_content"], + template="Now you try to talk about {new_content}.", + ) + + examples = [ + {"question": "foo", "answer": "bar"}, + {"question": "baz", "answer": "foo"}, + ] + with pytest.raises(ValueError): + FewShotPromptWithTemplates( + suffix=suffix, + prefix=prefix, + input_variables=[], + examples=examples, + example_prompt=EXAMPLE_PROMPT, + example_separator="\n", + validate_template=True, + ) + assert FewShotPromptWithTemplates( + suffix=suffix, + prefix=prefix, + input_variables=[], + examples=examples, + example_prompt=EXAMPLE_PROMPT, + example_separator="\n", + ).input_variables == ["content", "new_content"] diff --git a/libs/langchain/tests/unit_tests/prompts/test_prompt.py b/libs/langchain/tests/unit_tests/prompts/test_prompt.py index 87221df45e..7966a3067b 100644 --- a/libs/langchain/tests/unit_tests/prompts/test_prompt.py +++ b/libs/langchain/tests/unit_tests/prompts/test_prompt.py @@ -39,7 +39,12 @@ def test_prompt_missing_input_variables() -> None: template = "This is a {foo} test." input_variables: list = [] with pytest.raises(ValueError): - PromptTemplate(input_variables=input_variables, template=template) + PromptTemplate( + input_variables=input_variables, template=template, validate_template=True + ) + assert PromptTemplate( + input_variables=input_variables, template=template + ).input_variables == ["foo"] def test_prompt_extra_input_variables() -> None: @@ -47,7 +52,12 @@ def test_prompt_extra_input_variables() -> None: template = "This is a {foo} test." input_variables = ["foo", "bar"] with pytest.raises(ValueError): - PromptTemplate(input_variables=input_variables, template=template) + PromptTemplate( + input_variables=input_variables, template=template, validate_template=True + ) + assert PromptTemplate( + input_variables=input_variables, template=template + ).input_variables == ["foo"] def test_prompt_wrong_input_variables() -> None: @@ -55,7 +65,12 @@ def test_prompt_wrong_input_variables() -> None: template = "This is a {foo} test." input_variables = ["bar"] with pytest.raises(ValueError): - PromptTemplate(input_variables=input_variables, template=template) + PromptTemplate( + input_variables=input_variables, template=template, validate_template=True + ) + assert PromptTemplate( + input_variables=input_variables, template=template + ).input_variables == ["foo"] def test_prompt_from_examples_valid() -> None: @@ -229,8 +244,14 @@ def test_prompt_jinja2_missing_input_variables() -> None: input_variables: list = [] with pytest.warns(UserWarning): PromptTemplate( - input_variables=input_variables, template=template, template_format="jinja2" + input_variables=input_variables, + template=template, + template_format="jinja2", + validate_template=True, ) + assert PromptTemplate( + input_variables=input_variables, template=template, template_format="jinja2" + ).input_variables == ["foo"] @pytest.mark.requires("jinja2") @@ -240,8 +261,14 @@ def test_prompt_jinja2_extra_input_variables() -> None: input_variables = ["foo", "bar"] with pytest.warns(UserWarning): PromptTemplate( - input_variables=input_variables, template=template, template_format="jinja2" + input_variables=input_variables, + template=template, + template_format="jinja2", + validate_template=True, ) + assert PromptTemplate( + input_variables=input_variables, template=template, template_format="jinja2" + ).input_variables == ["foo"] @pytest.mark.requires("jinja2") @@ -251,5 +278,11 @@ def test_prompt_jinja2_wrong_input_variables() -> None: input_variables = ["bar"] with pytest.warns(UserWarning): PromptTemplate( - input_variables=input_variables, template=template, template_format="jinja2" + input_variables=input_variables, + template=template, + template_format="jinja2", + validate_template=True, ) + assert PromptTemplate( + input_variables=input_variables, template=template, template_format="jinja2" + ).input_variables == ["foo"]