mirror of
https://github.com/hwchase17/langchain
synced 2024-10-29 17:07:25 +00:00
6efd5fa2b9
With respect to recent changes, ChatPromptTemplate does not accept partial variables. This PR should fix that issue. Fixes #6431 #### Who can review? @hwchase17 --------- Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
193 lines
6.5 KiB
Python
193 lines
6.5 KiB
Python
from pathlib import Path
|
|
from typing import List
|
|
|
|
import pytest
|
|
|
|
from langchain.prompts import PromptTemplate
|
|
from langchain.prompts.chat import (
|
|
AIMessagePromptTemplate,
|
|
BaseMessagePromptTemplate,
|
|
ChatMessagePromptTemplate,
|
|
ChatPromptTemplate,
|
|
ChatPromptValue,
|
|
HumanMessagePromptTemplate,
|
|
SystemMessagePromptTemplate,
|
|
)
|
|
from langchain.schema import HumanMessage
|
|
|
|
|
|
def create_messages() -> List[BaseMessagePromptTemplate]:
    """Build one message prompt template of each kind used by these tests.

    Together the templates reference the variables ``context``, ``foo`` and
    ``bar``. They are returned in the order: system, human, AI, and a generic
    chat message whose role is ``"test"``.
    """
    system_template = PromptTemplate(
        template="Here's some context: {context}",
        input_variables=["context"],
    )
    human_template = PromptTemplate(
        template="Hello {foo}, I'm {bar}. Thanks for the {context}",
        input_variables=["foo", "bar", "context"],
    )
    ai_template = PromptTemplate(
        template="I'm an AI. I'm {foo}. I'm {bar}.",
        input_variables=["foo", "bar"],
    )
    generic_template = PromptTemplate(
        template="I'm a generic message. I'm {foo}. I'm {bar}.",
        input_variables=["foo", "bar"],
    )
    return [
        SystemMessagePromptTemplate(prompt=system_template),
        HumanMessagePromptTemplate(prompt=human_template),
        AIMessagePromptTemplate(prompt=ai_template),
        ChatMessagePromptTemplate(role="test", prompt=generic_template),
    ]
|
|
|
|
|
|
def create_chat_prompt_template() -> ChatPromptTemplate:
    """Return a ChatPromptTemplate built from ``create_messages()``.

    The declared input variables are exactly the ones those message
    templates reference: ``foo``, ``bar`` and ``context``.
    """
    message_templates = create_messages()
    return ChatPromptTemplate(
        messages=message_templates,
        input_variables=["foo", "bar", "context"],
    )
|
|
|
|
|
|
def test_create_chat_prompt_template_from_template() -> None:
    """``from_template`` wraps the template string in a single human message."""
    chat_prompt = ChatPromptTemplate.from_template("hi {foo} {bar}")
    expected_messages = [
        HumanMessagePromptTemplate.from_template("hi {foo} {bar}")
    ]
    assert chat_prompt.messages == expected_messages
|
|
|
|
|
|
def test_create_chat_prompt_template_from_template_partial() -> None:
    """Partials given to ``from_template`` end up on the inner PromptTemplate.

    ``foo`` is supplied as a partial, so only ``bar`` should remain as an
    input variable of the wrapped prompt.
    """
    chat_prompt = ChatPromptTemplate.from_template(
        "hi {foo} {bar}", partial_variables={"foo": "jim"}
    )
    assert len(chat_prompt.messages) == 1
    # Safe to unpack: the length was asserted just above.
    [message_template] = chat_prompt.messages
    assert isinstance(message_template, HumanMessagePromptTemplate)

    expected_inner_prompt = PromptTemplate(
        template="hi {foo} {bar}",
        input_variables=["bar"],
        partial_variables={"foo": "jim"},
    )
    assert message_template.prompt == expected_inner_prompt
|
|
|
|
|
|
def test_message_prompt_template_from_template_file() -> None:
    """A message template loaded from a file equals one built inline.

    NOTE(review): presumes ``data/prompt_file.txt`` (two directories up from
    this test file) contains ``"Question: {question}\\nAnswer:"`` — confirm
    against the fixture if this test starts failing.
    """
    template_path = Path(__file__).parent.parent / "data" / "prompt_file.txt"
    inline = ChatMessagePromptTemplate(
        prompt=PromptTemplate(
            template="Question: {question}\nAnswer:", input_variables=["question"]
        ),
        role="human",
    )
    loaded = ChatMessagePromptTemplate.from_template_file(
        template_path, ["question"], role="human"
    )
    assert inline == loaded
|
|
|
|
|
|
def test_chat_prompt_template() -> None:
    """Formatting renders each message and the role-prefixed string form."""
    template = create_chat_prompt_template()
    variables = {"foo": "foo", "bar": "bar", "context": "context"}

    prompt_value = template.format_prompt(**variables)
    assert isinstance(prompt_value, ChatPromptValue)

    rendered = prompt_value.to_messages()
    assert len(rendered) == 4
    expected_contents = [
        "Here's some context: context",
        "Hello foo, I'm bar. Thanks for the context",
        "I'm an AI. I'm foo. I'm bar.",
        "I'm a generic message. I'm foo. I'm bar.",
    ]
    for message, wanted in zip(rendered, expected_contents):
        assert message.content == wanted

    # Flattened form: one "<role>: <content>" line per message.
    expected = (
        "System: Here's some context: context\n"
        "Human: Hello foo, I'm bar. Thanks for the context\n"
        "AI: I'm an AI. I'm foo. I'm bar.\n"
        "test: I'm a generic message. I'm foo. I'm bar."
    )
    assert prompt_value.to_string() == expected

    # format() should agree with format_prompt(...).to_string().
    assert template.format(**variables) == expected
|
|
|
|
|
|
def test_chat_prompt_template_from_messages() -> None:
    """``from_messages`` infers input variables from the given templates."""
    templates = create_messages()
    chat_prompt_template = ChatPromptTemplate.from_messages(templates)

    # Sorting (rather than using sets) also catches duplicated variables.
    inferred = sorted(chat_prompt_template.input_variables)
    assert inferred == sorted(["context", "foo", "bar"])
    assert len(chat_prompt_template.messages) == 4
|
|
|
|
|
|
def test_chat_prompt_template_with_messages() -> None:
    """Literal messages can be mixed with templates.

    The literal message adds no input variables and comes back unchanged
    as the last formatted message.
    """
    chat_prompt_template = ChatPromptTemplate.from_messages(
        [*create_messages(), HumanMessage(content="foo")]
    )
    assert sorted(chat_prompt_template.input_variables) == sorted(
        ["context", "foo", "bar"]
    )
    assert len(chat_prompt_template.messages) == 5

    prompt_value = chat_prompt_template.format_prompt(
        context="see", foo="this", bar="magic"
    )
    *_, last_message = prompt_value.to_messages()
    assert last_message == HumanMessage(content="foo")
|
|
|
|
|
|
def test_chat_invalid_input_variables_extra() -> None:
    """Declaring an input variable that no message uses raises ValueError."""
    literal_messages = [HumanMessage(content="foo")]
    with pytest.raises(ValueError):
        ChatPromptTemplate(input_variables=["foo"], messages=literal_messages)
|
|
|
|
|
|
def test_chat_invalid_input_variables_missing() -> None:
    """Omitting a variable that a message template needs raises ValueError."""
    templates = [HumanMessagePromptTemplate.from_template("{foo}")]
    with pytest.raises(ValueError):
        ChatPromptTemplate(input_variables=[], messages=templates)
|
|
|
|
|
|
def test_infer_variables() -> None:
    """Input variables are inferred when they are not declared explicitly."""
    prompt = ChatPromptTemplate(
        messages=[HumanMessagePromptTemplate.from_template("{foo}")]
    )
    assert prompt.input_variables == ["foo"]
|
|
|
|
|
|
def test_chat_valid_with_partial_variables() -> None:
    """Explicit input variables may omit those supplied as partials.

    ``formatins`` is referenced by the template but provided through
    ``partial_variables``, so declaring only ``question`` and ``context``
    must be accepted, and the partials must be stored unchanged.
    """
    messages = [
        HumanMessagePromptTemplate.from_template(
            "Do something with {question} using {context} giving it like {formatins}"
        )
    ]
    prompt = ChatPromptTemplate(
        messages=messages,
        input_variables=["question", "context"],
        partial_variables={"formatins": "some structure"},
    )
    assert set(prompt.input_variables) == {"question", "context"}
    assert prompt.partial_variables == {"formatins": "some structure"}
|
|
|
|
|
|
def test_chat_valid_infer_variables() -> None:
    """Inferred input variables exclude those supplied as partials.

    No ``input_variables`` are passed, so they are inferred from the
    template. ``formatins`` comes from ``partial_variables`` and therefore
    must not appear among the inferred inputs.
    """
    messages = [
        HumanMessagePromptTemplate.from_template(
            "Do something with {question} using {context} giving it like {formatins}"
        )
    ]
    prompt = ChatPromptTemplate(
        messages=messages, partial_variables={"formatins": "some structure"}
    )
    assert set(prompt.input_variables) == {"question", "context"}
    assert prompt.partial_variables == {"formatins": "some structure"}
|