core[minor]: Implement aformat for FewShotPromptWithTemplates (#20039)

pull/20229/head
Christophe Bornet 6 months ago committed by GitHub
parent 855ba46f80
commit 19001e6cb9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -101,6 +101,14 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
else:
raise ValueError
async def _aget_examples(self, **kwargs: Any) -> List[dict]:
if self.examples is not None:
return self.examples
elif self.example_selector is not None:
return await self.example_selector.aselect_examples(kwargs)
else:
raise ValueError
def format(self, **kwargs: Any) -> str:
"""Format the prompt with the inputs.
@ -149,6 +157,42 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
# Format the template with the input variables.
return DEFAULT_FORMATTER_MAPPING[self.template_format](template, **kwargs)
async def aformat(self, **kwargs: Any) -> str:
kwargs = self._merge_partial_and_user_variables(**kwargs)
# Get the examples to use.
examples = await self._aget_examples(**kwargs)
# Format the examples.
example_strings = [
# We can use the sync method here as PromptTemplate doesn't block
self.example_prompt.format(**example)
for example in examples
]
# Create the overall prefix.
if self.prefix is None:
prefix = ""
else:
prefix_kwargs = {
k: v for k, v in kwargs.items() if k in self.prefix.input_variables
}
for k in prefix_kwargs.keys():
kwargs.pop(k)
prefix = await self.prefix.aformat(**prefix_kwargs)
# Create the overall suffix
suffix_kwargs = {
k: v for k, v in kwargs.items() if k in self.suffix.input_variables
}
for k in suffix_kwargs.keys():
kwargs.pop(k)
suffix = await self.suffix.aformat(
**suffix_kwargs,
)
pieces = [prefix, *example_strings, suffix]
template = self.example_separator.join([piece for piece in pieces if piece])
# Format the template with the input variables.
return DEFAULT_FORMATTER_MAPPING[self.template_format](template, **kwargs)
@property
def _prompt_type(self) -> str:
"""Return the prompt type key."""

@ -10,7 +10,7 @@ EXAMPLE_PROMPT = PromptTemplate(
)
def test_prompttemplate_prefix_suffix() -> None:
async def test_prompttemplate_prefix_suffix() -> None:
"""Test that few shot works when prefix and suffix are PromptTemplates."""
prefix = PromptTemplate(
input_variables=["content"], template="This is a test about {content}."
@ -32,13 +32,15 @@ def test_prompttemplate_prefix_suffix() -> None:
example_prompt=EXAMPLE_PROMPT,
example_separator="\n",
)
output = prompt.format(content="animals", new_content="party")
expected_output = (
"This is a test about animals.\n"
"foo: bar\n"
"baz: foo\n"
"Now you try to talk about party."
)
output = prompt.format(content="animals", new_content="party")
assert output == expected_output
output = await prompt.aformat(content="animals", new_content="party")
assert output == expected_output

Loading…
Cancel
Save