docstrings prompts (#7844)

Added missed docstrings in `prompts`
@baskaryan
This commit is contained in:
Leonid Ganeline 2023-07-18 07:58:22 -07:00 committed by GitHub
parent dda11d2a05
commit 4a05b7f772
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 173 additions and 12 deletions

View File

@ -95,7 +95,10 @@ def check_valid_template(
class StringPromptValue(PromptValue):
"""String prompt value."""
text: str
"""Prompt text."""
def to_string(self) -> str:
"""Return prompt as string."""

View File

@ -25,27 +25,53 @@ from langchain.schema.messages import (
class BaseMessagePromptTemplate(Serializable, ABC):
"""Base class for message prompt templates."""
@property
def lc_serializable(self) -> bool:
"""Whether this object should be serialized.
Returns:
Whether this object should be serialized.
"""
return True
@abstractmethod
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
"""To messages."""
"""Format messages from kwargs. Should return a list of BaseMessages.
Args:
**kwargs: Keyword arguments to use for formatting.
Returns:
List of BaseMessages.
"""
@property
@abstractmethod
def input_variables(self) -> List[str]:
"""Input variables for this prompt template."""
"""Input variables for this prompt template.
Returns:
List of input variables.
"""
class MessagesPlaceholder(BaseMessagePromptTemplate):
"""Prompt template that assumes variable is already list of messages."""
variable_name: str
"""Name of variable to use as messages."""
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
"""To a BaseMessage."""
"""To a BaseMessage.
Args:
**kwargs: Keyword arguments to use for formatting.
Returns:
List of BaseMessage.
"""
value = kwargs[self.variable_name]
if not isinstance(value, list):
raise ValueError(
@ -62,18 +88,27 @@ class MessagesPlaceholder(BaseMessagePromptTemplate):
@property
def input_variables(self) -> List[str]:
"""Input variables for this prompt template."""
"""Input variables for this prompt template.
Returns:
List of input variable names.
"""
return [self.variable_name]
MessagePromptTemplateT = TypeVar(
"MessagePromptTemplateT", bound="BaseStringMessagePromptTemplate"
)
"""Type variable for message prompt templates."""
class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC):
"""Base class for message prompt templates that use a string prompt template."""
prompt: StringPromptTemplate
"""String prompt template."""
additional_kwargs: dict = Field(default_factory=dict)
"""Additional keyword arguments to pass to the prompt template."""
@classmethod
def from_template(
@ -82,6 +117,16 @@ class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC):
template_format: str = "f-string",
**kwargs: Any,
) -> MessagePromptTemplateT:
"""Create a class from a string template.
Args:
template: a template.
template_format: format of the template.
**kwargs: keyword arguments to pass to the constructor.
Returns:
A new instance of this class.
"""
prompt = PromptTemplate.from_template(template, template_format=template_format)
return cls(prompt=prompt, **kwargs)
@ -92,6 +137,16 @@ class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC):
input_variables: List[str],
**kwargs: Any,
) -> MessagePromptTemplateT:
"""Create a class from a template file.
Args:
template_file: path to a template file. String or Path.
input_variables: list of input variables.
**kwargs: keyword arguments to pass to the constructor.
Returns:
A new instance of this class.
"""
prompt = PromptTemplate.from_file(template_file, input_variables)
return cls(prompt=prompt, **kwargs)
@ -100,15 +155,32 @@ class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC):
"""To a BaseMessage."""
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
"""Format messages from kwargs. Should return a list of BaseMessages.
Args:
**kwargs: Keyword arguments to use for formatting.
Returns:
List of BaseMessages.
"""
return [self.format(**kwargs)]
@property
def input_variables(self) -> List[str]:
"""
Input variables for this prompt template.
Returns:
List of input variable names.
"""
return self.prompt.input_variables
class ChatMessagePromptTemplate(BaseStringMessagePromptTemplate):
"""Chat message prompt template."""
role: str
"""Role of the message."""
def format(self, **kwargs: Any) -> BaseMessage:
text = self.prompt.format(**kwargs)
@ -118,40 +190,61 @@ class ChatMessagePromptTemplate(BaseStringMessagePromptTemplate):
class HumanMessagePromptTemplate(BaseStringMessagePromptTemplate):
"""Human message prompt template. This is a message that is sent to the user."""
def format(self, **kwargs: Any) -> BaseMessage:
text = self.prompt.format(**kwargs)
return HumanMessage(content=text, additional_kwargs=self.additional_kwargs)
class AIMessagePromptTemplate(BaseStringMessagePromptTemplate):
"""AI message prompt template. This is a message that is not sent to the user."""
def format(self, **kwargs: Any) -> BaseMessage:
text = self.prompt.format(**kwargs)
return AIMessage(content=text, additional_kwargs=self.additional_kwargs)
class SystemMessagePromptTemplate(BaseStringMessagePromptTemplate):
"""System message prompt template.
This is a message that is not sent to the user.
"""
def format(self, **kwargs: Any) -> BaseMessage:
text = self.prompt.format(**kwargs)
return SystemMessage(content=text, additional_kwargs=self.additional_kwargs)
class ChatPromptValue(PromptValue):
"""Chat prompt value."""
messages: List[BaseMessage]
"""List of messages."""
def to_string(self) -> str:
"""Return prompt as string."""
return get_buffer_string(self.messages)
def to_messages(self) -> List[BaseMessage]:
"""Return prompt as messages."""
"""Return prompt as a list of messages."""
return self.messages
class BaseChatPromptTemplate(BasePromptTemplate, ABC):
"""Base class for chat prompt templates."""
def format(self, **kwargs: Any) -> str:
return self.format_prompt(**kwargs).to_string()
def format_prompt(self, **kwargs: Any) -> PromptValue:
"""
Format prompt. Should return a PromptValue.
Args:
**kwargs: Keyword arguments to use for formatting.
Returns:
PromptValue.
"""
messages = self.format_messages(**kwargs)
return ChatPromptValue(messages=messages)
@ -161,11 +254,25 @@ class BaseChatPromptTemplate(BasePromptTemplate, ABC):
class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
"""Chat prompt template. This is a prompt that is sent to the user."""
input_variables: List[str]
"""List of input variables."""
messages: List[Union[BaseMessagePromptTemplate, BaseMessage]]
"""List of messages."""
@root_validator(pre=True)
def validate_input_variables(cls, values: dict) -> dict:
"""
Validate input variables. If input_variables is not set, it will be set to
the union of all input variables in the messages.
Args:
values: values to validate.
Returns:
Validated values.
"""
messages = values["messages"]
input_vars = set()
for message in messages:
@ -186,6 +293,15 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
@classmethod
def from_template(cls, template: str, **kwargs: Any) -> ChatPromptTemplate:
"""Create a class from a template.
Args:
template: template string.
**kwargs: keyword arguments to pass to the constructor.
Returns:
A new instance of this class.
"""
prompt_template = PromptTemplate.from_template(template, **kwargs)
message = HumanMessagePromptTemplate(prompt=prompt_template)
return cls.from_messages([message])
@ -194,6 +310,14 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
def from_role_strings(
cls, string_messages: List[Tuple[str, str]]
) -> ChatPromptTemplate:
"""Create a class from a list of (role, template) tuples.
Args:
string_messages: list of (role, template) tuples.
Returns:
A new instance of this class.
"""
messages = [
ChatMessagePromptTemplate(
prompt=PromptTemplate.from_template(template), role=role
@ -206,6 +330,14 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
def from_strings(
cls, string_messages: List[Tuple[Type[BaseMessagePromptTemplate], str]]
) -> ChatPromptTemplate:
"""Create a class from a list of (role, template) tuples.
Args:
string_messages: list of (role, template) tuples.
Returns:
A new instance of this class.
"""
messages = [
role(prompt=PromptTemplate.from_template(template))
for role, template in string_messages
@ -216,6 +348,15 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
def from_messages(
cls, messages: Sequence[Union[BaseMessagePromptTemplate, BaseMessage]]
) -> ChatPromptTemplate:
"""
Create a class from a list of messages.
Args:
messages: list of messages.
Returns:
A new instance of this class.
"""
input_vars = set()
for message in messages:
if isinstance(message, BaseMessagePromptTemplate):
@ -226,6 +367,15 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
return self.format_prompt(**kwargs).to_string()
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
"""
Format kwargs into a list of messages.
Args:
**kwargs: keyword arguments to use for formatting.
Returns:
List of messages.
"""
kwargs = self._merge_partial_and_user_variables(**kwargs)
result = []
for message_template in self.messages:
@ -251,4 +401,10 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
return "chat"
def save(self, file_path: Union[Path, str]) -> None:
"""
Save prompt to file.
Args:
file_path: path to file.
"""
raise NotImplementedError

View File

@ -67,7 +67,7 @@ class FewShotPromptTemplate(StringPromptTemplate):
@root_validator()
def template_is_valid(cls, values: Dict) -> Dict:
"""Check that prefix, suffix and input variables are consistent."""
"""Check that prefix, suffix, and input variables are consistent."""
if values["validate_template"]:
check_valid_template(
values["prefix"] + values["suffix"],

View File

@ -59,7 +59,7 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
@root_validator()
def template_is_valid(cls, values: Dict) -> Dict:
"""Check that prefix, suffix and input variables are consistent."""
"""Check that prefix, suffix, and input variables are consistent."""
if values["validate_template"]:
input_variables = values["input_variables"]
expected_input_variables = set(values["suffix"].input_variables)

View File

@ -1,4 +1,4 @@
"""Load prompts from disk."""
"""Load prompts."""
import importlib
import json
import logging
@ -31,7 +31,7 @@ def load_prompt_from_config(config: dict) -> BasePromptTemplate:
def _load_template(var_name: str, config: dict) -> dict:
"""Load template from disk if applicable."""
"""Load template from the path if applicable."""
# Check if template_path exists in config.
if f"{var_name}_path" in config:
# If it does, make sure template variable doesn't also exist.
@ -88,7 +88,7 @@ def _load_output_parser(config: dict) -> dict:
def _load_few_shot_prompt(config: dict) -> FewShotPromptTemplate:
"""Load the few shot prompt from the config."""
"""Load the "few shot" prompt from the config."""
# Load the suffix and prefix templates.
config = _load_template("suffix", config)
config = _load_template("prefix", config)
@ -128,7 +128,7 @@ def load_prompt(path: Union[str, Path]) -> BasePromptTemplate:
def _load_prompt_from_file(file: Union[str, Path]) -> BasePromptTemplate:
"""Load prompt from file."""
# Convert file to Path object.
# Convert file to a Path object.
if isinstance(file, str):
file_path = Path(file)
else:

View File

@ -11,7 +11,7 @@ def _get_inputs(inputs: dict, input_variables: List[str]) -> dict:
class PipelinePromptTemplate(BasePromptTemplate):
"""A prompt template for composing multiple prompts together.
"""A prompt template for composing multiple prompt templates together.
This can be useful when you want to reuse parts of prompts.
A PipelinePrompt consists of two main parts:
@ -24,7 +24,9 @@ class PipelinePromptTemplate(BasePromptTemplate):
"""
final_prompt: BasePromptTemplate
"""The final prompt that is returned."""
pipeline_prompts: List[Tuple[str, BasePromptTemplate]]
"""A list of tuples, consisting of a string (`name`) and a Prompt Template."""
@root_validator(pre=True)
def get_input_variables(cls, values: Dict) -> Dict: