You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
langchain/langchain/prompts/few_shot.py

128 lines
4.3 KiB
Python

"""Prompt template that contains few shot examples."""
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Extra, root_validator
from langchain.prompts.base import (
DEFAULT_FORMATTER_MAPPING,
BasePromptTemplate,
check_valid_template,
)
from langchain.prompts.example_selector.base import BaseExampleSelector
from langchain.prompts.prompt import PromptTemplate
class FewShotPromptTemplate(BasePromptTemplate, BaseModel):
    """Prompt template that contains few shot examples.

    The final prompt is built by joining ``prefix``, every formatted example
    and ``suffix`` with ``example_separator`` (empty pieces are dropped), and
    then formatting the joined string with the formatter selected by
    ``template_format``.
    """

    examples: Optional[List[dict]] = None
    """Examples to format into the prompt.
    Either this or example_selector should be provided."""

    example_selector: Optional[BaseExampleSelector] = None
    """ExampleSelector to choose the examples to format into the prompt.
    Either this or examples should be provided."""

    example_prompt: PromptTemplate
    """PromptTemplate used to format an individual example."""

    suffix: str
    """A prompt template string to put after the examples."""

    input_variables: List[str]
    """A list of the names of the variables the prompt template expects."""

    example_separator: str = "\n\n"
    """String separator used to join the prefix, the examples, and suffix."""

    prefix: str = ""
    """A prompt template string to put before the examples."""

    template_format: str = "f-string"
    """The format of the prompt template. Options are: 'f-string', 'jinja2'."""

    validate_template: bool = True
    """Whether or not to try validating the template."""

    @root_validator(pre=True)
    def check_examples_and_selector(cls, values: Dict) -> Dict:
        """Check that one and only one of examples/example_selector are provided.

        Args:
            values: Raw (pre-validation) field values passed to the model.

        Returns:
            The unchanged ``values`` mapping.

        Raises:
            ValueError: If both are provided, or if neither is provided.
        """
        examples = values.get("examples", None)
        example_selector = values.get("example_selector", None)
        # NOTE: this check is truthiness-based, so an empty ``examples`` list
        # given together with a selector is accepted here; ``_get_examples``
        # will then return that empty list because it tests ``is not None``.
        if examples and example_selector:
            raise ValueError(
                "Only one of 'examples' and 'example_selector' should be provided"
            )
        if examples is None and example_selector is None:
            raise ValueError(
                "One of 'examples' and 'example_selector' should be provided"
            )
        return values

    # skip_on_failure: when a field already failed validation its key is
    # missing from ``values``; indexing it below would raise a KeyError that
    # hides the real validation error, so do not run in that case.
    @root_validator(skip_on_failure=True)
    def template_is_valid(cls, values: Dict) -> Dict:
        """Check that prefix, suffix and input variables are consistent.

        Args:
            values: Validated field values of the model.

        Returns:
            The unchanged ``values`` mapping.
        """
        if values["validate_template"]:
            # Prefix and suffix are the only parts formatted with the caller's
            # variables, so they are validated together as one template.
            check_valid_template(
                values["prefix"] + values["suffix"],
                values["template_format"],
                values["input_variables"] + list(values["partial_variables"]),
            )
        return values

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    def _get_examples(self, **kwargs: Any) -> List[dict]:
        """Return the examples to render for the given input variables.

        Args:
            kwargs: Input variables; handed to the example selector as a dict
                when a selector is used, ignored when ``examples`` is set.

        Returns:
            ``self.examples`` when it is not None, otherwise whatever the
            selector's ``select_examples`` returns.

        Raises:
            ValueError: If neither ``examples`` nor ``example_selector`` is set.
        """
        if self.examples is not None:
            return self.examples
        if self.example_selector is not None:
            return self.example_selector.select_examples(kwargs)
        raise ValueError(
            "One of 'examples' and 'example_selector' should be provided"
        )

    def format(self, **kwargs: Any) -> str:
        """Format the prompt with the inputs.

        Args:
            kwargs: Any arguments to be passed to the prompt template.

        Returns:
            A formatted string.

        Example:

        .. code-block:: python

            prompt.format(variable1="foo")
        """
        # Helper is defined outside this class (not visible here); presumably
        # it combines stored partial variables with the caller's kwargs.
        kwargs = self._merge_partial_and_user_variables(**kwargs)
        # Get the examples to use.
        examples = self._get_examples(**kwargs)
        # Format the examples.
        example_strings = [
            self.example_prompt.format(**example) for example in examples
        ]
        # Create the overall template; empty pieces (e.g. the default empty
        # prefix) are dropped so they do not produce stray separators.
        pieces = [self.prefix, *example_strings, self.suffix]
        template = self.example_separator.join([piece for piece in pieces if piece])
        # Format the template with the input variables.
        return DEFAULT_FORMATTER_MAPPING[self.template_format](template, **kwargs)

    @property
    def _prompt_type(self) -> str:
        """Return the prompt type key."""
        return "few_shot"

    def dict(self, **kwargs: Any) -> Dict:
        """Return a dictionary of the prompt.

        Args:
            kwargs: Forwarded unchanged to the parent ``dict`` implementation.

        Raises:
            ValueError: If an example selector is set.
        """
        if self.example_selector:
            raise ValueError("Saving an example selector is not currently supported")
        return super().dict(**kwargs)