core[minor]: Add aformat_messages to FewShotChatMessagePromptTemplate and ChatPromptTemplate (#19648)

Needed since the example selector may use a vector store.
pull/19537/head
Christophe Bornet 6 months ago committed by GitHub
parent 5f814820f6
commit 6b2b511f68
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -66,6 +66,17 @@ class BaseMessagePromptTemplate(Serializable, ABC):
List of BaseMessages. List of BaseMessages.
""" """
async def aformat_messages(self, **kwargs: Any) -> List[BaseMessage]:
"""Format messages from kwargs. Should return a list of BaseMessages.
Args:
**kwargs: Keyword arguments to use for formatting.
Returns:
List of BaseMessages.
"""
return self.format_messages(**kwargs)
@property @property
@abstractmethod @abstractmethod
def input_variables(self) -> List[str]: def input_variables(self) -> List[str]:
@ -594,6 +605,10 @@ class BaseChatPromptTemplate(BasePromptTemplate, ABC):
def format_messages(self, **kwargs: Any) -> List[BaseMessage]: def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
"""Format kwargs into a list of messages.""" """Format kwargs into a list of messages."""
async def aformat_messages(self, **kwargs: Any) -> List[BaseMessage]:
"""Format kwargs into a list of messages."""
return self.format_messages(**kwargs)
def pretty_repr(self, html: bool = False) -> str: def pretty_repr(self, html: bool = False) -> str:
"""Human-readable representation.""" """Human-readable representation."""
raise NotImplementedError raise NotImplementedError
@ -901,19 +916,31 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
partial_variables=partial_vars, partial_variables=partial_vars,
) )
def format(self, **kwargs: Any) -> str: def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
"""Format the chat template into a string. """Format the chat template into a list of finalized messages.
Args: Args:
**kwargs: keyword arguments to use for filling in template variables **kwargs: keyword arguments to use for filling in template variables
in all the template messages in this chat template. in all the template messages in this chat template.
Returns: Returns:
formatted string list of formatted messages
""" """
return self.format_prompt(**kwargs).to_string() kwargs = self._merge_partial_and_user_variables(**kwargs)
result = []
for message_template in self.messages:
if isinstance(message_template, BaseMessage):
result.extend([message_template])
elif isinstance(
message_template, (BaseMessagePromptTemplate, BaseChatPromptTemplate)
):
message = message_template.format_messages(**kwargs)
result.extend(message)
else:
raise ValueError(f"Unexpected input: {message_template}")
return result
def format_messages(self, **kwargs: Any) -> List[BaseMessage]: async def aformat_messages(self, **kwargs: Any) -> List[BaseMessage]:
"""Format the chat template into a list of finalized messages. """Format the chat template into a list of finalized messages.
Args: Args:
@ -931,7 +958,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
elif isinstance( elif isinstance(
message_template, (BaseMessagePromptTemplate, BaseChatPromptTemplate) message_template, (BaseMessagePromptTemplate, BaseChatPromptTemplate)
): ):
message = message_template.format_messages(**kwargs) message = await message_template.aformat_messages(**kwargs)
result.extend(message) result.extend(message)
else: else:
raise ValueError(f"Unexpected input: {message_template}") raise ValueError(f"Unexpected input: {message_template}")

@ -5,6 +5,7 @@ from __future__ import annotations
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Union from typing import Any, Dict, List, Literal, Optional, Union
from langchain_core.example_selectors import BaseExampleSelector
from langchain_core.messages import BaseMessage, get_buffer_string from langchain_core.messages import BaseMessage, get_buffer_string
from langchain_core.prompts.chat import ( from langchain_core.prompts.chat import (
BaseChatPromptTemplate, BaseChatPromptTemplate,
@ -27,7 +28,7 @@ class _FewShotPromptTemplateMixin(BaseModel):
"""Examples to format into the prompt. """Examples to format into the prompt.
Either this or example_selector should be provided.""" Either this or example_selector should be provided."""
example_selector: Any = None example_selector: Optional[BaseExampleSelector] = None
"""ExampleSelector to choose the examples to format into the prompt. """ExampleSelector to choose the examples to format into the prompt.
Either this or examples should be provided.""" Either this or examples should be provided."""
@ -72,6 +73,24 @@ class _FewShotPromptTemplateMixin(BaseModel):
"One of 'examples' and 'example_selector' should be provided" "One of 'examples' and 'example_selector' should be provided"
) )
async def _aget_examples(self, **kwargs: Any) -> List[dict]:
"""Get the examples to use for formatting the prompt.
Args:
**kwargs: Keyword arguments to be passed to the example selector.
Returns:
List of examples.
"""
if self.examples is not None:
return self.examples
elif self.example_selector is not None:
return await self.example_selector.aselect_examples(kwargs)
else:
raise ValueError(
"One of 'examples' and 'example_selector' should be provided"
)
class FewShotPromptTemplate(_FewShotPromptTemplateMixin, StringPromptTemplate): class FewShotPromptTemplate(_FewShotPromptTemplateMixin, StringPromptTemplate):
"""Prompt template that contains few shot examples.""" """Prompt template that contains few shot examples."""
@ -325,6 +344,28 @@ class FewShotChatMessagePromptTemplate(
] ]
return messages return messages
async def aformat_messages(self, **kwargs: Any) -> List[BaseMessage]:
"""Format kwargs into a list of messages.
Args:
**kwargs: keyword arguments to use for filling in templates in messages.
Returns:
A list of formatted messages with all template variables filled in.
"""
# Get the examples to use.
examples = await self._aget_examples(**kwargs)
examples = [
{k: e[k] for k in self.example_prompt.input_variables} for e in examples
]
# Format the examples.
messages = [
message
for example in examples
for message in await self.example_prompt.aformat_messages(**example)
]
return messages
def format(self, **kwargs: Any) -> str: def format(self, **kwargs: Any) -> str:
"""Format the prompt with inputs generating a string. """Format the prompt with inputs generating a string.

@ -308,7 +308,7 @@ def test_prompt_jinja2_extra_input_variables(
).input_variables == ["bar", "foo"] ).input_variables == ["bar", "foo"]
def test_few_shot_chat_message_prompt_template() -> None: async def test_few_shot_chat_message_prompt_template() -> None:
"""Tests for few shot chat message template.""" """Tests for few shot chat message template."""
examples = [ examples = [
{"input": "2+2", "output": "4"}, {"input": "2+2", "output": "4"},
@ -333,8 +333,7 @@ def test_few_shot_chat_message_prompt_template() -> None:
+ HumanMessagePromptTemplate.from_template("{input}") + HumanMessagePromptTemplate.from_template("{input}")
) )
messages = final_prompt.format_messages(input="100 + 1") expected = [
assert messages == [
SystemMessage(content="You are a helpful AI Assistant", additional_kwargs={}), SystemMessage(content="You are a helpful AI Assistant", additional_kwargs={}),
HumanMessage(content="2+2", additional_kwargs={}, example=False), HumanMessage(content="2+2", additional_kwargs={}, example=False),
AIMessage(content="4", additional_kwargs={}, example=False), AIMessage(content="4", additional_kwargs={}, example=False),
@ -343,6 +342,11 @@ def test_few_shot_chat_message_prompt_template() -> None:
HumanMessage(content="100 + 1", additional_kwargs={}, example=False), HumanMessage(content="100 + 1", additional_kwargs={}, example=False),
] ]
messages = final_prompt.format_messages(input="100 + 1")
assert messages == expected
messages = await final_prompt.aformat_messages(input="100 + 1")
assert messages == expected
class AsIsSelector(BaseExampleSelector): class AsIsSelector(BaseExampleSelector):
"""An example selector for testing purposes. """An example selector for testing purposes.
@ -355,11 +359,9 @@ class AsIsSelector(BaseExampleSelector):
self.examples = examples self.examples = examples
def add_example(self, example: Dict[str, str]) -> Any: def add_example(self, example: Dict[str, str]) -> Any:
"""Adds an example to the selector.""" raise NotImplementedError
raise NotImplementedError()
def select_examples(self, input_variables: Dict[str, str]) -> List[dict]: def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
"""Select which examples to use based on the inputs."""
return list(self.examples) return list(self.examples)
@ -387,8 +389,63 @@ def test_few_shot_chat_message_prompt_template_with_selector() -> None:
+ few_shot_prompt + few_shot_prompt
+ HumanMessagePromptTemplate.from_template("{input}") + HumanMessagePromptTemplate.from_template("{input}")
) )
expected = [
SystemMessage(content="You are a helpful AI Assistant", additional_kwargs={}),
HumanMessage(content="2+2", additional_kwargs={}, example=False),
AIMessage(content="4", additional_kwargs={}, example=False),
HumanMessage(content="2+3", additional_kwargs={}, example=False),
AIMessage(content="5", additional_kwargs={}, example=False),
HumanMessage(content="100 + 1", additional_kwargs={}, example=False),
]
messages = final_prompt.format_messages(input="100 + 1") messages = final_prompt.format_messages(input="100 + 1")
assert messages == [ assert messages == expected
class AsyncAsIsSelector(BaseExampleSelector):
"""An example selector for testing purposes.
This selector returns the examples as-is.
"""
def __init__(self, examples: Sequence[Dict[str, str]]) -> None:
"""Initializes the selector."""
self.examples = examples
def add_example(self, example: Dict[str, str]) -> Any:
raise NotImplementedError
def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
raise NotImplementedError
async def aselect_examples(self, input_variables: Dict[str, str]) -> List[dict]:
return list(self.examples)
async def test_few_shot_chat_message_prompt_template_with_selector_async() -> None:
"""Tests for few shot chat message template with an async example selector."""
examples = [
{"input": "2+2", "output": "4"},
{"input": "2+3", "output": "5"},
]
example_selector = AsyncAsIsSelector(examples)
example_prompt = ChatPromptTemplate.from_messages(
[
HumanMessagePromptTemplate.from_template("{input}"),
AIMessagePromptTemplate.from_template("{output}"),
]
)
few_shot_prompt = FewShotChatMessagePromptTemplate(
input_variables=["input"],
example_prompt=example_prompt,
example_selector=example_selector,
)
final_prompt: ChatPromptTemplate = (
SystemMessagePromptTemplate.from_template("You are a helpful AI Assistant")
+ few_shot_prompt
+ HumanMessagePromptTemplate.from_template("{input}")
)
expected = [
SystemMessage(content="You are a helpful AI Assistant", additional_kwargs={}), SystemMessage(content="You are a helpful AI Assistant", additional_kwargs={}),
HumanMessage(content="2+2", additional_kwargs={}, example=False), HumanMessage(content="2+2", additional_kwargs={}, example=False),
AIMessage(content="4", additional_kwargs={}, example=False), AIMessage(content="4", additional_kwargs={}, example=False),
@ -396,3 +453,5 @@ def test_few_shot_chat_message_prompt_template_with_selector() -> None:
AIMessage(content="5", additional_kwargs={}, example=False), AIMessage(content="5", additional_kwargs={}, example=False),
HumanMessage(content="100 + 1", additional_kwargs={}, example=False), HumanMessage(content="100 + 1", additional_kwargs={}, example=False),
] ]
messages = await final_prompt.aformat_messages(input="100 + 1")
assert messages == expected

Loading…
Cancel
Save