Merge pull request #20038

* Implement aformat_messages for ChatMessagePromptTemplate
This commit is contained in:
Christophe Bornet 2024-04-05 16:25:27 +02:00 committed by GitHub
parent ebd24bb5d6
commit 927793d088
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -297,6 +297,17 @@ class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC):
Formatted message. Formatted message.
""" """
async def aformat(self, **kwargs: Any) -> BaseMessage:
"""Format the prompt template.
Args:
**kwargs: Keyword arguments to use for formatting.
Returns:
Formatted message.
"""
return self.format(**kwargs)
def format_messages(self, **kwargs: Any) -> List[BaseMessage]: def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
"""Format messages from kwargs. """Format messages from kwargs.
@ -308,6 +319,9 @@ class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC):
""" """
return [self.format(**kwargs)] return [self.format(**kwargs)]
async def aformat_messages(self, **kwargs: Any) -> List[BaseMessage]:
return [await self.aformat(**kwargs)]
@property @property
def input_variables(self) -> List[str]: def input_variables(self) -> List[str]:
""" """
@ -350,6 +364,12 @@ class ChatMessagePromptTemplate(BaseStringMessagePromptTemplate):
content=text, role=self.role, additional_kwargs=self.additional_kwargs content=text, role=self.role, additional_kwargs=self.additional_kwargs
) )
async def aformat(self, **kwargs: Any) -> BaseMessage:
text = await self.prompt.aformat(**kwargs)
return ChatMessage(
content=text, role=self.role, additional_kwargs=self.additional_kwargs
)
_StringImageMessagePromptTemplateT = TypeVar( _StringImageMessagePromptTemplateT = TypeVar(
"_StringImageMessagePromptTemplateT", bound="_StringImageMessagePromptTemplate" "_StringImageMessagePromptTemplateT", bound="_StringImageMessagePromptTemplate"