From f43b48aebc1bbfe5495fc3d52a7ead2493706ca0 Mon Sep 17 00:00:00 2001 From: Christophe Bornet Date: Tue, 9 Apr 2024 21:59:39 +0200 Subject: [PATCH] core[minor]: Implement aformat_messages for _StringImageMessagePromptTemplate (#20036) --- libs/core/langchain_core/prompts/chat.py | 31 +++++++++ .../tests/unit_tests/prompts/test_chat.py | 63 ++++++++++++++----- 2 files changed, 78 insertions(+), 16 deletions(-) diff --git a/libs/core/langchain_core/prompts/chat.py b/libs/core/langchain_core/prompts/chat.py index d03461e7ff..e35050bc64 100644 --- a/libs/core/langchain_core/prompts/chat.py +++ b/libs/core/langchain_core/prompts/chat.py @@ -506,6 +506,9 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate): """ return [self.format(**kwargs)] + async def aformat_messages(self, **kwargs: Any) -> List[BaseMessage]: + return [await self.aformat(**kwargs)] + @property def input_variables(self) -> List[str]: """ @@ -546,6 +549,34 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate): content=content, additional_kwargs=self.additional_kwargs ) + async def aformat(self, **kwargs: Any) -> BaseMessage: + """Format the prompt template. + + Args: + **kwargs: Keyword arguments to use for formatting. + + Returns: + Formatted message. + """ + if isinstance(self.prompt, StringPromptTemplate): + text = await self.prompt.aformat(**kwargs) + return self._msg_class( + content=text, additional_kwargs=self.additional_kwargs + ) + else: + content: List = [] + for prompt in self.prompt: + inputs = {var: kwargs[var] for var in prompt.input_variables} + if isinstance(prompt, StringPromptTemplate): + formatted: Union[str, ImageURL] = await prompt.aformat(**inputs) + content.append({"type": "text", "text": formatted}) + elif isinstance(prompt, ImagePromptTemplate): + formatted = await prompt.aformat(**inputs) + content.append({"type": "image_url", "image_url": formatted}) + return self._msg_class( + content=content, additional_kwargs=self.additional_kwargs + ) + def pretty_repr(self, html: bool = False) -> str: # TODO: Handle partials title = self.__class__.__name__.replace("MessagePromptTemplate", " Message") diff --git a/libs/core/tests/unit_tests/prompts/test_chat.py b/libs/core/tests/unit_tests/prompts/test_chat.py index 152c612fe1..67bd4209db 100644 --- a/libs/core/tests/unit_tests/prompts/test_chat.py +++ b/libs/core/tests/unit_tests/prompts/test_chat.py @@ -143,6 +143,9 @@ async def test_chat_prompt_template(chat_prompt_template: ChatPromptTemplate) -> string = chat_prompt_template.format(foo="foo", bar="bar", context="context") assert string == expected + string = await chat_prompt_template.aformat(foo="foo", bar="bar", context="context") + assert string == expected + def test_chat_prompt_template_from_messages( messages: List[BaseMessagePromptTemplate], @@ -155,7 +158,7 @@ def test_chat_prompt_template_from_messages( assert len(chat_prompt_template.messages) == 4 -def test_chat_prompt_template_from_messages_using_role_strings() -> None: +async def test_chat_prompt_template_from_messages_using_role_strings() -> None: """Test creating a chat prompt template from role string messages.""" template = ChatPromptTemplate.from_messages( [ @@ -166,9 +169,7 @@ def test_chat_prompt_template_from_messages_using_role_strings() -> None: ] ) - messages = template.format_messages(name="Bob", user_input="What is your name?") - - assert messages == [ + expected = [ SystemMessage( content="You are a helpful AI bot. Your name is Bob.", additional_kwargs={} ), @@ -181,6 +182,14 @@ def test_chat_prompt_template_from_messages_using_role_strings() -> None: HumanMessage(content="What is your name?", additional_kwargs={}, example=False), ] + messages = template.format_messages(name="Bob", user_input="What is your name?") + assert messages == expected + + messages = await template.aformat_messages( + name="Bob", user_input="What is your name?" + ) + assert messages == expected + def test_chat_prompt_template_with_messages( messages: List[BaseMessagePromptTemplate], @@ -262,7 +271,7 @@ def test_chat_valid_infer_variables() -> None: assert prompt.partial_variables == {"formatins": "some structure"} -def test_chat_from_role_strings() -> None: +async def test_chat_from_role_strings() -> None: """Test instantiation of chat template from role strings.""" with pytest.warns(LangChainPendingDeprecationWarning): template = ChatPromptTemplate.from_role_strings( @@ -274,14 +283,19 @@ def test_chat_from_role_strings() -> None: ] ) - messages = template.format_messages(question="How are you?", quack="duck") - assert messages == [ + expected = [ ChatMessage(content="You are a bot.", role="system"), ChatMessage(content="hello!", role="assistant"), ChatMessage(content="How are you?", role="human"), ChatMessage(content="duck", role="other"), ] + messages = template.format_messages(question="How are you?", quack="duck") + assert messages == expected + + messages = await template.aformat_messages(question="How are you?", quack="duck") + assert messages == expected + @pytest.mark.parametrize( "args,expected", @@ -385,7 +399,7 @@ def test_chat_message_partial() -> None: assert template2.format(input="hello") == get_buffer_string(expected) -def test_chat_tmpl_from_messages_multipart_text() -> None: +async def test_chat_tmpl_from_messages_multipart_text() -> None: template = ChatPromptTemplate.from_messages( [ ("system", "You are an AI assistant named {name}."), @@ -398,7 +412,6 @@ def test_chat_tmpl_from_messages_multipart_text() -> None: ), ] ) - messages = template.format_messages(name="R2D2") expected = [ SystemMessage(content="You are an AI assistant named R2D2."), HumanMessage( @@ -408,10 +421,14 @@ def test_chat_tmpl_from_messages_multipart_text() -> None: ] ), ] + messages = template.format_messages(name="R2D2") + assert messages == expected + + messages = await template.aformat_messages(name="R2D2") assert messages == expected -def test_chat_tmpl_from_messages_multipart_text_with_template() -> None: +async def test_chat_tmpl_from_messages_multipart_text_with_template() -> None: template = ChatPromptTemplate.from_messages( [ ("system", "You are an AI assistant named {name}."), @@ -424,7 +441,6 @@ def test_chat_tmpl_from_messages_multipart_text_with_template() -> None: ), ] ) - messages = template.format_messages(name="R2D2", object_name="image") expected = [ SystemMessage(content="You are an AI assistant named R2D2."), HumanMessage( @@ -434,10 +450,14 @@ def test_chat_tmpl_from_messages_multipart_text_with_template() -> None: ] ), ] + messages = template.format_messages(name="R2D2", object_name="image") + assert messages == expected + + messages = await template.aformat_messages(name="R2D2", object_name="image") assert messages == expected -def test_chat_tmpl_from_messages_multipart_image() -> None: +async def test_chat_tmpl_from_messages_multipart_image() -> None: base64_image = "iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAA" other_base64_image = "iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAA" template = ChatPromptTemplate.from_messages( @@ -472,9 +492,6 @@ def test_chat_tmpl_from_messages_multipart_image() -> None: ), ] ) - messages = template.format_messages( - name="R2D2", my_image=base64_image, my_other_image=other_base64_image - ) expected = [ SystemMessage(content="You are an AI assistant named R2D2."), HumanMessage( @@ -512,6 +529,14 @@ def test_chat_tmpl_from_messages_multipart_image() -> None: ] ), ] + messages = template.format_messages( + name="R2D2", my_image=base64_image, my_other_image=other_base64_image + ) + assert messages == expected + + messages = await template.aformat_messages( + name="R2D2", my_image=base64_image, my_other_image=other_base64_image + ) assert messages == expected @@ -566,14 +591,20 @@ def test_chat_prompt_message_placeholder_tuple() -> None: assert optional_prompt.format_messages() == [] -def test_messages_prompt_accepts_list() -> None: +async def test_messages_prompt_accepts_list() -> None: prompt = ChatPromptTemplate.from_messages([MessagesPlaceholder("history")]) value = prompt.invoke([("user", "Hi there")]) # type: ignore assert value.to_messages() == [HumanMessage(content="Hi there")] + value = await prompt.ainvoke([("user", "Hi there")]) # type: ignore + assert value.to_messages() == [HumanMessage(content="Hi there")] + # Assert still raises a nice error prompt = ChatPromptTemplate.from_messages( [("system", "You are a {foo}"), MessagesPlaceholder("history")] ) with pytest.raises(TypeError): prompt.invoke([("user", "Hi there")]) # type: ignore + + with pytest.raises(TypeError): + await prompt.ainvoke([("user", "Hi there")]) # type: ignore