docstrings prompts (#7844)

Added missed docstrings in `prompts`
@baskaryan
This commit is contained in:
Leonid Ganeline 2023-07-18 07:58:22 -07:00 committed by GitHub
parent dda11d2a05
commit 4a05b7f772
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 173 additions and 12 deletions

View File

@ -95,7 +95,10 @@ def check_valid_template(
class StringPromptValue(PromptValue): class StringPromptValue(PromptValue):
"""String prompt value."""
text: str text: str
"""Prompt text."""
def to_string(self) -> str: def to_string(self) -> str:
"""Return prompt as string.""" """Return prompt as string."""

View File

@ -25,27 +25,53 @@ from langchain.schema.messages import (
class BaseMessagePromptTemplate(Serializable, ABC): class BaseMessagePromptTemplate(Serializable, ABC):
"""Base class for message prompt templates."""
@property @property
def lc_serializable(self) -> bool: def lc_serializable(self) -> bool:
"""Whether this object should be serialized.
Returns:
Whether this object should be serialized.
"""
return True return True
@abstractmethod @abstractmethod
def format_messages(self, **kwargs: Any) -> List[BaseMessage]: def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
"""To messages.""" """Format messages from kwargs. Should return a list of BaseMessages.
Args:
**kwargs: Keyword arguments to use for formatting.
Returns:
List of BaseMessages.
"""
@property @property
@abstractmethod @abstractmethod
def input_variables(self) -> List[str]: def input_variables(self) -> List[str]:
"""Input variables for this prompt template.""" """Input variables for this prompt template.
Returns:
List of input variables.
"""
class MessagesPlaceholder(BaseMessagePromptTemplate): class MessagesPlaceholder(BaseMessagePromptTemplate):
"""Prompt template that assumes variable is already list of messages.""" """Prompt template that assumes variable is already list of messages."""
variable_name: str variable_name: str
"""Name of variable to use as messages."""
def format_messages(self, **kwargs: Any) -> List[BaseMessage]: def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
"""To a BaseMessage.""" """To a BaseMessage.
Args:
**kwargs: Keyword arguments to use for formatting.
Returns:
List of BaseMessage.
"""
value = kwargs[self.variable_name] value = kwargs[self.variable_name]
if not isinstance(value, list): if not isinstance(value, list):
raise ValueError( raise ValueError(
@ -62,18 +88,27 @@ class MessagesPlaceholder(BaseMessagePromptTemplate):
@property @property
def input_variables(self) -> List[str]: def input_variables(self) -> List[str]:
"""Input variables for this prompt template.""" """Input variables for this prompt template.
Returns:
List of input variable names.
"""
return [self.variable_name] return [self.variable_name]
MessagePromptTemplateT = TypeVar( MessagePromptTemplateT = TypeVar(
"MessagePromptTemplateT", bound="BaseStringMessagePromptTemplate" "MessagePromptTemplateT", bound="BaseStringMessagePromptTemplate"
) )
"""Type variable for message prompt templates."""
class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC): class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC):
"""Base class for message prompt templates that use a string prompt template."""
prompt: StringPromptTemplate prompt: StringPromptTemplate
"""String prompt template."""
additional_kwargs: dict = Field(default_factory=dict) additional_kwargs: dict = Field(default_factory=dict)
"""Additional keyword arguments to pass to the prompt template."""
@classmethod @classmethod
def from_template( def from_template(
@ -82,6 +117,16 @@ class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC):
template_format: str = "f-string", template_format: str = "f-string",
**kwargs: Any, **kwargs: Any,
) -> MessagePromptTemplateT: ) -> MessagePromptTemplateT:
"""Create a class from a string template.
Args:
template: a template.
template_format: format of the template.
**kwargs: keyword arguments to pass to the constructor.
Returns:
A new instance of this class.
"""
prompt = PromptTemplate.from_template(template, template_format=template_format) prompt = PromptTemplate.from_template(template, template_format=template_format)
return cls(prompt=prompt, **kwargs) return cls(prompt=prompt, **kwargs)
@ -92,6 +137,16 @@ class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC):
input_variables: List[str], input_variables: List[str],
**kwargs: Any, **kwargs: Any,
) -> MessagePromptTemplateT: ) -> MessagePromptTemplateT:
"""Create a class from a template file.
Args:
template_file: path to a template file. String or Path.
input_variables: list of input variables.
**kwargs: keyword arguments to pass to the constructor.
Returns:
A new instance of this class.
"""
prompt = PromptTemplate.from_file(template_file, input_variables) prompt = PromptTemplate.from_file(template_file, input_variables)
return cls(prompt=prompt, **kwargs) return cls(prompt=prompt, **kwargs)
@ -100,15 +155,32 @@ class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC):
"""To a BaseMessage.""" """To a BaseMessage."""
def format_messages(self, **kwargs: Any) -> List[BaseMessage]: def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
"""Format messages from kwargs. Should return a list of BaseMessages.
Args:
**kwargs: Keyword arguments to use for formatting.
Returns:
List of BaseMessages.
"""
return [self.format(**kwargs)] return [self.format(**kwargs)]
@property @property
def input_variables(self) -> List[str]: def input_variables(self) -> List[str]:
"""
Input variables for this prompt template.
Returns:
List of input variable names.
"""
return self.prompt.input_variables return self.prompt.input_variables
class ChatMessagePromptTemplate(BaseStringMessagePromptTemplate): class ChatMessagePromptTemplate(BaseStringMessagePromptTemplate):
"""Chat message prompt template."""
role: str role: str
"""Role of the message."""
def format(self, **kwargs: Any) -> BaseMessage: def format(self, **kwargs: Any) -> BaseMessage:
text = self.prompt.format(**kwargs) text = self.prompt.format(**kwargs)
@ -118,40 +190,61 @@ class ChatMessagePromptTemplate(BaseStringMessagePromptTemplate):
class HumanMessagePromptTemplate(BaseStringMessagePromptTemplate): class HumanMessagePromptTemplate(BaseStringMessagePromptTemplate):
"""Human message prompt template. This is a message that is sent to the user."""
def format(self, **kwargs: Any) -> BaseMessage: def format(self, **kwargs: Any) -> BaseMessage:
text = self.prompt.format(**kwargs) text = self.prompt.format(**kwargs)
return HumanMessage(content=text, additional_kwargs=self.additional_kwargs) return HumanMessage(content=text, additional_kwargs=self.additional_kwargs)
class AIMessagePromptTemplate(BaseStringMessagePromptTemplate): class AIMessagePromptTemplate(BaseStringMessagePromptTemplate):
"""AI message prompt template. This is a message that is not sent to the user."""
def format(self, **kwargs: Any) -> BaseMessage: def format(self, **kwargs: Any) -> BaseMessage:
text = self.prompt.format(**kwargs) text = self.prompt.format(**kwargs)
return AIMessage(content=text, additional_kwargs=self.additional_kwargs) return AIMessage(content=text, additional_kwargs=self.additional_kwargs)
class SystemMessagePromptTemplate(BaseStringMessagePromptTemplate): class SystemMessagePromptTemplate(BaseStringMessagePromptTemplate):
"""System message prompt template.
This is a message that is not sent to the user.
"""
def format(self, **kwargs: Any) -> BaseMessage: def format(self, **kwargs: Any) -> BaseMessage:
text = self.prompt.format(**kwargs) text = self.prompt.format(**kwargs)
return SystemMessage(content=text, additional_kwargs=self.additional_kwargs) return SystemMessage(content=text, additional_kwargs=self.additional_kwargs)
class ChatPromptValue(PromptValue): class ChatPromptValue(PromptValue):
"""Chat prompt value."""
messages: List[BaseMessage] messages: List[BaseMessage]
"""List of messages."""
def to_string(self) -> str: def to_string(self) -> str:
"""Return prompt as string.""" """Return prompt as string."""
return get_buffer_string(self.messages) return get_buffer_string(self.messages)
def to_messages(self) -> List[BaseMessage]: def to_messages(self) -> List[BaseMessage]:
"""Return prompt as messages.""" """Return prompt as a list of messages."""
return self.messages return self.messages
class BaseChatPromptTemplate(BasePromptTemplate, ABC): class BaseChatPromptTemplate(BasePromptTemplate, ABC):
"""Base class for chat prompt templates."""
def format(self, **kwargs: Any) -> str: def format(self, **kwargs: Any) -> str:
return self.format_prompt(**kwargs).to_string() return self.format_prompt(**kwargs).to_string()
def format_prompt(self, **kwargs: Any) -> PromptValue: def format_prompt(self, **kwargs: Any) -> PromptValue:
"""
Format prompt. Should return a PromptValue.
Args:
**kwargs: Keyword arguments to use for formatting.
Returns:
PromptValue.
"""
messages = self.format_messages(**kwargs) messages = self.format_messages(**kwargs)
return ChatPromptValue(messages=messages) return ChatPromptValue(messages=messages)
@ -161,11 +254,25 @@ class BaseChatPromptTemplate(BasePromptTemplate, ABC):
class ChatPromptTemplate(BaseChatPromptTemplate, ABC): class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
"""Chat prompt template. This is a prompt that is sent to the user."""
input_variables: List[str] input_variables: List[str]
"""List of input variables."""
messages: List[Union[BaseMessagePromptTemplate, BaseMessage]] messages: List[Union[BaseMessagePromptTemplate, BaseMessage]]
"""List of messages."""
@root_validator(pre=True) @root_validator(pre=True)
def validate_input_variables(cls, values: dict) -> dict: def validate_input_variables(cls, values: dict) -> dict:
"""
Validate input variables. If input_variables is not set, it will be set to
the union of all input variables in the messages.
Args:
values: values to validate.
Returns:
Validated values.
"""
messages = values["messages"] messages = values["messages"]
input_vars = set() input_vars = set()
for message in messages: for message in messages:
@ -186,6 +293,15 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
@classmethod @classmethod
def from_template(cls, template: str, **kwargs: Any) -> ChatPromptTemplate: def from_template(cls, template: str, **kwargs: Any) -> ChatPromptTemplate:
"""Create a class from a template.
Args:
template: template string.
**kwargs: keyword arguments to pass to the constructor.
Returns:
A new instance of this class.
"""
prompt_template = PromptTemplate.from_template(template, **kwargs) prompt_template = PromptTemplate.from_template(template, **kwargs)
message = HumanMessagePromptTemplate(prompt=prompt_template) message = HumanMessagePromptTemplate(prompt=prompt_template)
return cls.from_messages([message]) return cls.from_messages([message])
@ -194,6 +310,14 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
def from_role_strings( def from_role_strings(
cls, string_messages: List[Tuple[str, str]] cls, string_messages: List[Tuple[str, str]]
) -> ChatPromptTemplate: ) -> ChatPromptTemplate:
"""Create a class from a list of (role, template) tuples.
Args:
string_messages: list of (role, template) tuples.
Returns:
A new instance of this class.
"""
messages = [ messages = [
ChatMessagePromptTemplate( ChatMessagePromptTemplate(
prompt=PromptTemplate.from_template(template), role=role prompt=PromptTemplate.from_template(template), role=role
@ -206,6 +330,14 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
def from_strings( def from_strings(
cls, string_messages: List[Tuple[Type[BaseMessagePromptTemplate], str]] cls, string_messages: List[Tuple[Type[BaseMessagePromptTemplate], str]]
) -> ChatPromptTemplate: ) -> ChatPromptTemplate:
"""Create a class from a list of (role, template) tuples.
Args:
string_messages: list of (role, template) tuples.
Returns:
A new instance of this class.
"""
messages = [ messages = [
role(prompt=PromptTemplate.from_template(template)) role(prompt=PromptTemplate.from_template(template))
for role, template in string_messages for role, template in string_messages
@ -216,6 +348,15 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
def from_messages( def from_messages(
cls, messages: Sequence[Union[BaseMessagePromptTemplate, BaseMessage]] cls, messages: Sequence[Union[BaseMessagePromptTemplate, BaseMessage]]
) -> ChatPromptTemplate: ) -> ChatPromptTemplate:
"""
Create a class from a list of messages.
Args:
messages: list of messages.
Returns:
A new instance of this class.
"""
input_vars = set() input_vars = set()
for message in messages: for message in messages:
if isinstance(message, BaseMessagePromptTemplate): if isinstance(message, BaseMessagePromptTemplate):
@ -226,6 +367,15 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
return self.format_prompt(**kwargs).to_string() return self.format_prompt(**kwargs).to_string()
def format_messages(self, **kwargs: Any) -> List[BaseMessage]: def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
"""
Format kwargs into a list of messages.
Args:
**kwargs: keyword arguments to use for formatting.
Returns:
List of messages.
"""
kwargs = self._merge_partial_and_user_variables(**kwargs) kwargs = self._merge_partial_and_user_variables(**kwargs)
result = [] result = []
for message_template in self.messages: for message_template in self.messages:
@ -251,4 +401,10 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
return "chat" return "chat"
def save(self, file_path: Union[Path, str]) -> None: def save(self, file_path: Union[Path, str]) -> None:
"""
Save prompt to file.
Args:
file_path: path to file.
"""
raise NotImplementedError raise NotImplementedError

View File

@ -67,7 +67,7 @@ class FewShotPromptTemplate(StringPromptTemplate):
@root_validator() @root_validator()
def template_is_valid(cls, values: Dict) -> Dict: def template_is_valid(cls, values: Dict) -> Dict:
"""Check that prefix, suffix and input variables are consistent.""" """Check that prefix, suffix, and input variables are consistent."""
if values["validate_template"]: if values["validate_template"]:
check_valid_template( check_valid_template(
values["prefix"] + values["suffix"], values["prefix"] + values["suffix"],

View File

@ -59,7 +59,7 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
@root_validator() @root_validator()
def template_is_valid(cls, values: Dict) -> Dict: def template_is_valid(cls, values: Dict) -> Dict:
"""Check that prefix, suffix and input variables are consistent.""" """Check that prefix, suffix, and input variables are consistent."""
if values["validate_template"]: if values["validate_template"]:
input_variables = values["input_variables"] input_variables = values["input_variables"]
expected_input_variables = set(values["suffix"].input_variables) expected_input_variables = set(values["suffix"].input_variables)

View File

@ -1,4 +1,4 @@
"""Load prompts from disk.""" """Load prompts."""
import importlib import importlib
import json import json
import logging import logging
@ -31,7 +31,7 @@ def load_prompt_from_config(config: dict) -> BasePromptTemplate:
def _load_template(var_name: str, config: dict) -> dict: def _load_template(var_name: str, config: dict) -> dict:
"""Load template from disk if applicable.""" """Load template from the path if applicable."""
# Check if template_path exists in config. # Check if template_path exists in config.
if f"{var_name}_path" in config: if f"{var_name}_path" in config:
# If it does, make sure template variable doesn't also exist. # If it does, make sure template variable doesn't also exist.
@ -88,7 +88,7 @@ def _load_output_parser(config: dict) -> dict:
def _load_few_shot_prompt(config: dict) -> FewShotPromptTemplate: def _load_few_shot_prompt(config: dict) -> FewShotPromptTemplate:
"""Load the few shot prompt from the config.""" """Load the "few shot" prompt from the config."""
# Load the suffix and prefix templates. # Load the suffix and prefix templates.
config = _load_template("suffix", config) config = _load_template("suffix", config)
config = _load_template("prefix", config) config = _load_template("prefix", config)
@ -128,7 +128,7 @@ def load_prompt(path: Union[str, Path]) -> BasePromptTemplate:
def _load_prompt_from_file(file: Union[str, Path]) -> BasePromptTemplate: def _load_prompt_from_file(file: Union[str, Path]) -> BasePromptTemplate:
"""Load prompt from file.""" """Load prompt from file."""
# Convert file to Path object. # Convert file to a Path object.
if isinstance(file, str): if isinstance(file, str):
file_path = Path(file) file_path = Path(file)
else: else:

View File

@ -11,7 +11,7 @@ def _get_inputs(inputs: dict, input_variables: List[str]) -> dict:
class PipelinePromptTemplate(BasePromptTemplate): class PipelinePromptTemplate(BasePromptTemplate):
"""A prompt template for composing multiple prompts together. """A prompt template for composing multiple prompt templates together.
This can be useful when you want to reuse parts of prompts. This can be useful when you want to reuse parts of prompts.
A PipelinePrompt consists of two main parts: A PipelinePrompt consists of two main parts:
@ -24,7 +24,9 @@ class PipelinePromptTemplate(BasePromptTemplate):
""" """
final_prompt: BasePromptTemplate final_prompt: BasePromptTemplate
"""The final prompt that is returned."""
pipeline_prompts: List[Tuple[str, BasePromptTemplate]] pipeline_prompts: List[Tuple[str, BasePromptTemplate]]
"""A list of tuples, consisting of a string (`name`) and a Prompt Template."""
@root_validator(pre=True) @root_validator(pre=True)
def get_input_variables(cls, values: Dict) -> Dict: def get_input_variables(cls, values: Dict) -> Dict: