diff --git a/libs/core/langchain_core/prompts/chat.py b/libs/core/langchain_core/prompts/chat.py index 12f8d21baa..cd27e0ef5c 100644 --- a/libs/core/langchain_core/prompts/chat.py +++ b/libs/core/langchain_core/prompts/chat.py @@ -66,6 +66,17 @@ class BaseMessagePromptTemplate(Serializable, ABC): List of BaseMessages. """ + async def aformat_messages(self, **kwargs: Any) -> List[BaseMessage]: + """Format messages from kwargs. Should return a list of BaseMessages. + + Args: + **kwargs: Keyword arguments to use for formatting. + + Returns: + List of BaseMessages. + """ + return self.format_messages(**kwargs) + @property @abstractmethod def input_variables(self) -> List[str]: @@ -594,6 +605,10 @@ class BaseChatPromptTemplate(BasePromptTemplate, ABC): def format_messages(self, **kwargs: Any) -> List[BaseMessage]: """Format kwargs into a list of messages.""" + async def aformat_messages(self, **kwargs: Any) -> List[BaseMessage]: + """Format kwargs into a list of messages.""" + return self.format_messages(**kwargs) + def pretty_repr(self, html: bool = False) -> str: """Human-readable representation.""" raise NotImplementedError @@ -901,19 +916,31 @@ class ChatPromptTemplate(BaseChatPromptTemplate): partial_variables=partial_vars, ) - def format(self, **kwargs: Any) -> str: - """Format the chat template into a string. + def format_messages(self, **kwargs: Any) -> List[BaseMessage]: + """Format the chat template into a list of finalized messages. Args: **kwargs: keyword arguments to use for filling in template variables in all the template messages in this chat template. Returns: - formatted string + list of formatted messages """ - return self.format_prompt(**kwargs).to_string() + kwargs = self._merge_partial_and_user_variables(**kwargs) + result = [] + for message_template in self.messages: + if isinstance(message_template, BaseMessage): + result.extend([message_template]) + elif isinstance( + message_template, (BaseMessagePromptTemplate, BaseChatPromptTemplate) + ): + message = message_template.format_messages(**kwargs) + result.extend(message) + else: + raise ValueError(f"Unexpected input: {message_template}") + return result - def format_messages(self, **kwargs: Any) -> List[BaseMessage]: + async def aformat_messages(self, **kwargs: Any) -> List[BaseMessage]: """Format the chat template into a list of finalized messages. Args: @@ -931,7 +958,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate): elif isinstance( message_template, (BaseMessagePromptTemplate, BaseChatPromptTemplate) ): - message = message_template.format_messages(**kwargs) + message = await message_template.aformat_messages(**kwargs) result.extend(message) else: raise ValueError(f"Unexpected input: {message_template}") diff --git a/libs/core/langchain_core/prompts/few_shot.py b/libs/core/langchain_core/prompts/few_shot.py index e03513fb86..4c5ee2c5dc 100644 --- a/libs/core/langchain_core/prompts/few_shot.py +++ b/libs/core/langchain_core/prompts/few_shot.py @@ -5,6 +5,7 @@ from __future__ import annotations from pathlib import Path from typing import Any, Dict, List, Literal, Optional, Union +from langchain_core.example_selectors import BaseExampleSelector from langchain_core.messages import BaseMessage, get_buffer_string from langchain_core.prompts.chat import ( BaseChatPromptTemplate, @@ -27,7 +28,7 @@ class _FewShotPromptTemplateMixin(BaseModel): """Examples to format into the prompt. Either this or example_selector should be provided.""" - example_selector: Any = None + example_selector: Optional[BaseExampleSelector] = None """ExampleSelector to choose the examples to format into the prompt. Either this or examples should be provided.""" @@ -72,6 +73,24 @@ class _FewShotPromptTemplateMixin(BaseModel): "One of 'examples' and 'example_selector' should be provided" ) + async def _aget_examples(self, **kwargs: Any) -> List[dict]: + """Get the examples to use for formatting the prompt. + + Args: + **kwargs: Keyword arguments to be passed to the example selector. + + Returns: + List of examples. + """ + if self.examples is not None: + return self.examples + elif self.example_selector is not None: + return await self.example_selector.aselect_examples(kwargs) + else: + raise ValueError( + "One of 'examples' and 'example_selector' should be provided" + ) + class FewShotPromptTemplate(_FewShotPromptTemplateMixin, StringPromptTemplate): """Prompt template that contains few shot examples.""" @@ -325,6 +344,28 @@ class FewShotChatMessagePromptTemplate( ] return messages + async def aformat_messages(self, **kwargs: Any) -> List[BaseMessage]: + """Format kwargs into a list of messages. + + Args: + **kwargs: keyword arguments to use for filling in templates in messages. + + Returns: + A list of formatted messages with all template variables filled in. + """ + # Get the examples to use. + examples = await self._aget_examples(**kwargs) + examples = [ + {k: e[k] for k in self.example_prompt.input_variables} for e in examples + ] + # Format the examples. + messages = [ + message + for example in examples + for message in await self.example_prompt.aformat_messages(**example) + ] + return messages + def format(self, **kwargs: Any) -> str: """Format the prompt with inputs generating a string. diff --git a/libs/core/tests/unit_tests/prompts/test_few_shot.py b/libs/core/tests/unit_tests/prompts/test_few_shot.py index 7b4bc2378e..4129c3fe53 100644 --- a/libs/core/tests/unit_tests/prompts/test_few_shot.py +++ b/libs/core/tests/unit_tests/prompts/test_few_shot.py @@ -308,7 +308,7 @@ def test_prompt_jinja2_extra_input_variables( ).input_variables == ["bar", "foo"] -def test_few_shot_chat_message_prompt_template() -> None: +async def test_few_shot_chat_message_prompt_template() -> None: """Tests for few shot chat message template.""" examples = [ {"input": "2+2", "output": "4"}, @@ -333,8 +333,7 @@ def test_few_shot_chat_message_prompt_template() -> None: + HumanMessagePromptTemplate.from_template("{input}") ) - messages = final_prompt.format_messages(input="100 + 1") - assert messages == [ + expected = [ SystemMessage(content="You are a helpful AI Assistant", additional_kwargs={}), HumanMessage(content="2+2", additional_kwargs={}, example=False), AIMessage(content="4", additional_kwargs={}, example=False), @@ -343,6 +342,11 @@ def test_few_shot_chat_message_prompt_template() -> None: HumanMessage(content="100 + 1", additional_kwargs={}, example=False), ] + messages = final_prompt.format_messages(input="100 + 1") + assert messages == expected + messages = await final_prompt.aformat_messages(input="100 + 1") + assert messages == expected + class AsIsSelector(BaseExampleSelector): """An example selector for testing purposes. @@ -355,11 +359,9 @@ class AsIsSelector(BaseExampleSelector): self.examples = examples def add_example(self, example: Dict[str, str]) -> Any: - """Adds an example to the selector.""" - raise NotImplementedError() + raise NotImplementedError def select_examples(self, input_variables: Dict[str, str]) -> List[dict]: - """Select which examples to use based on the inputs.""" return list(self.examples) @@ -387,8 +389,63 @@ def test_few_shot_chat_message_prompt_template_with_selector() -> None: + few_shot_prompt + HumanMessagePromptTemplate.from_template("{input}") ) + expected = [ + SystemMessage(content="You are a helpful AI Assistant", additional_kwargs={}), + HumanMessage(content="2+2", additional_kwargs={}, example=False), + AIMessage(content="4", additional_kwargs={}, example=False), + HumanMessage(content="2+3", additional_kwargs={}, example=False), + AIMessage(content="5", additional_kwargs={}, example=False), + HumanMessage(content="100 + 1", additional_kwargs={}, example=False), + ] messages = final_prompt.format_messages(input="100 + 1") - assert messages == [ + assert messages == expected + + +class AsyncAsIsSelector(BaseExampleSelector): + """An example selector for testing purposes. + + This selector returns the examples as-is. + """ + + def __init__(self, examples: Sequence[Dict[str, str]]) -> None: + """Initializes the selector.""" + self.examples = examples + + def add_example(self, example: Dict[str, str]) -> Any: + raise NotImplementedError + + def select_examples(self, input_variables: Dict[str, str]) -> List[dict]: + raise NotImplementedError + + async def aselect_examples(self, input_variables: Dict[str, str]) -> List[dict]: + return list(self.examples) + + +async def test_few_shot_chat_message_prompt_template_with_selector_async() -> None: + """Tests for few shot chat message template with an async example selector.""" + examples = [ + {"input": "2+2", "output": "4"}, + {"input": "2+3", "output": "5"}, + ] + example_selector = AsyncAsIsSelector(examples) + example_prompt = ChatPromptTemplate.from_messages( + [ + HumanMessagePromptTemplate.from_template("{input}"), + AIMessagePromptTemplate.from_template("{output}"), + ] + ) + + few_shot_prompt = FewShotChatMessagePromptTemplate( + input_variables=["input"], + example_prompt=example_prompt, + example_selector=example_selector, + ) + final_prompt: ChatPromptTemplate = ( + SystemMessagePromptTemplate.from_template("You are a helpful AI Assistant") + + few_shot_prompt + + HumanMessagePromptTemplate.from_template("{input}") + ) + expected = [ SystemMessage(content="You are a helpful AI Assistant", additional_kwargs={}), HumanMessage(content="2+2", additional_kwargs={}, example=False), AIMessage(content="4", additional_kwargs={}, example=False), @@ -396,3 +453,5 @@ def test_few_shot_chat_message_prompt_template_with_selector() -> None: AIMessage(content="5", additional_kwargs={}, example=False), HumanMessage(content="100 + 1", additional_kwargs={}, example=False), ] + messages = await final_prompt.aformat_messages(input="100 + 1") + assert messages == expected