mirror of
https://github.com/hwchase17/langchain
synced 2024-11-08 07:10:35 +00:00
docstrings prompts
(#7844)
Added missed docstrings in `prompts` @baskaryan
This commit is contained in:
parent
dda11d2a05
commit
4a05b7f772
@ -95,7 +95,10 @@ def check_valid_template(
|
||||
|
||||
|
||||
class StringPromptValue(PromptValue):
|
||||
"""String prompt value."""
|
||||
|
||||
text: str
|
||||
"""Prompt text."""
|
||||
|
||||
def to_string(self) -> str:
|
||||
"""Return prompt as string."""
|
||||
|
@ -25,27 +25,53 @@ from langchain.schema.messages import (
|
||||
|
||||
|
||||
class BaseMessagePromptTemplate(Serializable, ABC):
|
||||
"""Base class for message prompt templates."""
|
||||
|
||||
@property
|
||||
def lc_serializable(self) -> bool:
|
||||
"""Whether this object should be serialized.
|
||||
|
||||
Returns:
|
||||
Whether this object should be serialized.
|
||||
"""
|
||||
return True
|
||||
|
||||
@abstractmethod
|
||||
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
||||
"""To messages."""
|
||||
"""Format messages from kwargs. Should return a list of BaseMessages.
|
||||
|
||||
Args:
|
||||
**kwargs: Keyword arguments to use for formatting.
|
||||
|
||||
Returns:
|
||||
List of BaseMessages.
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def input_variables(self) -> List[str]:
|
||||
"""Input variables for this prompt template."""
|
||||
"""Input variables for this prompt template.
|
||||
|
||||
Returns:
|
||||
List of input variables.
|
||||
"""
|
||||
|
||||
|
||||
class MessagesPlaceholder(BaseMessagePromptTemplate):
|
||||
"""Prompt template that assumes variable is already list of messages."""
|
||||
|
||||
variable_name: str
|
||||
"""Name of variable to use as messages."""
|
||||
|
||||
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
||||
"""To a BaseMessage."""
|
||||
"""To a BaseMessage.
|
||||
|
||||
Args:
|
||||
**kwargs: Keyword arguments to use for formatting.
|
||||
|
||||
Returns:
|
||||
List of BaseMessage.
|
||||
"""
|
||||
value = kwargs[self.variable_name]
|
||||
if not isinstance(value, list):
|
||||
raise ValueError(
|
||||
@ -62,18 +88,27 @@ class MessagesPlaceholder(BaseMessagePromptTemplate):
|
||||
|
||||
@property
|
||||
def input_variables(self) -> List[str]:
|
||||
"""Input variables for this prompt template."""
|
||||
"""Input variables for this prompt template.
|
||||
|
||||
Returns:
|
||||
List of input variable names.
|
||||
"""
|
||||
return [self.variable_name]
|
||||
|
||||
|
||||
MessagePromptTemplateT = TypeVar(
|
||||
"MessagePromptTemplateT", bound="BaseStringMessagePromptTemplate"
|
||||
)
|
||||
"""Type variable for message prompt templates."""
|
||||
|
||||
|
||||
class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC):
|
||||
"""Base class for message prompt templates that use a string prompt template."""
|
||||
|
||||
prompt: StringPromptTemplate
|
||||
"""String prompt template."""
|
||||
additional_kwargs: dict = Field(default_factory=dict)
|
||||
"""Additional keyword arguments to pass to the prompt template."""
|
||||
|
||||
@classmethod
|
||||
def from_template(
|
||||
@ -82,6 +117,16 @@ class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC):
|
||||
template_format: str = "f-string",
|
||||
**kwargs: Any,
|
||||
) -> MessagePromptTemplateT:
|
||||
"""Create a class from a string template.
|
||||
|
||||
Args:
|
||||
template: a template.
|
||||
template_format: format of the template.
|
||||
**kwargs: keyword arguments to pass to the constructor.
|
||||
|
||||
Returns:
|
||||
A new instance of this class.
|
||||
"""
|
||||
prompt = PromptTemplate.from_template(template, template_format=template_format)
|
||||
return cls(prompt=prompt, **kwargs)
|
||||
|
||||
@ -92,6 +137,16 @@ class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC):
|
||||
input_variables: List[str],
|
||||
**kwargs: Any,
|
||||
) -> MessagePromptTemplateT:
|
||||
"""Create a class from a template file.
|
||||
|
||||
Args:
|
||||
template_file: path to a template file. String or Path.
|
||||
input_variables: list of input variables.
|
||||
**kwargs: keyword arguments to pass to the constructor.
|
||||
|
||||
Returns:
|
||||
A new instance of this class.
|
||||
"""
|
||||
prompt = PromptTemplate.from_file(template_file, input_variables)
|
||||
return cls(prompt=prompt, **kwargs)
|
||||
|
||||
@ -100,15 +155,32 @@ class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC):
|
||||
"""To a BaseMessage."""
|
||||
|
||||
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
||||
"""Format messages from kwargs. Should return a list of BaseMessages.
|
||||
|
||||
Args:
|
||||
**kwargs: Keyword arguments to use for formatting.
|
||||
|
||||
Returns:
|
||||
List of BaseMessages.
|
||||
"""
|
||||
return [self.format(**kwargs)]
|
||||
|
||||
@property
|
||||
def input_variables(self) -> List[str]:
|
||||
"""
|
||||
Input variables for this prompt template.
|
||||
|
||||
Returns:
|
||||
List of input variable names.
|
||||
"""
|
||||
return self.prompt.input_variables
|
||||
|
||||
|
||||
class ChatMessagePromptTemplate(BaseStringMessagePromptTemplate):
|
||||
"""Chat message prompt template."""
|
||||
|
||||
role: str
|
||||
"""Role of the message."""
|
||||
|
||||
def format(self, **kwargs: Any) -> BaseMessage:
|
||||
text = self.prompt.format(**kwargs)
|
||||
@ -118,40 +190,61 @@ class ChatMessagePromptTemplate(BaseStringMessagePromptTemplate):
|
||||
|
||||
|
||||
class HumanMessagePromptTemplate(BaseStringMessagePromptTemplate):
|
||||
"""Human message prompt template. This is a message that is sent to the user."""
|
||||
|
||||
def format(self, **kwargs: Any) -> BaseMessage:
|
||||
text = self.prompt.format(**kwargs)
|
||||
return HumanMessage(content=text, additional_kwargs=self.additional_kwargs)
|
||||
|
||||
|
||||
class AIMessagePromptTemplate(BaseStringMessagePromptTemplate):
|
||||
"""AI message prompt template. This is a message that is not sent to the user."""
|
||||
|
||||
def format(self, **kwargs: Any) -> BaseMessage:
|
||||
text = self.prompt.format(**kwargs)
|
||||
return AIMessage(content=text, additional_kwargs=self.additional_kwargs)
|
||||
|
||||
|
||||
class SystemMessagePromptTemplate(BaseStringMessagePromptTemplate):
|
||||
"""System message prompt template.
|
||||
This is a message that is not sent to the user.
|
||||
"""
|
||||
|
||||
def format(self, **kwargs: Any) -> BaseMessage:
|
||||
text = self.prompt.format(**kwargs)
|
||||
return SystemMessage(content=text, additional_kwargs=self.additional_kwargs)
|
||||
|
||||
|
||||
class ChatPromptValue(PromptValue):
|
||||
"""Chat prompt value."""
|
||||
|
||||
messages: List[BaseMessage]
|
||||
"""List of messages."""
|
||||
|
||||
def to_string(self) -> str:
|
||||
"""Return prompt as string."""
|
||||
return get_buffer_string(self.messages)
|
||||
|
||||
def to_messages(self) -> List[BaseMessage]:
|
||||
"""Return prompt as messages."""
|
||||
"""Return prompt as a list of messages."""
|
||||
return self.messages
|
||||
|
||||
|
||||
class BaseChatPromptTemplate(BasePromptTemplate, ABC):
|
||||
"""Base class for chat prompt templates."""
|
||||
|
||||
def format(self, **kwargs: Any) -> str:
|
||||
return self.format_prompt(**kwargs).to_string()
|
||||
|
||||
def format_prompt(self, **kwargs: Any) -> PromptValue:
|
||||
"""
|
||||
Format prompt. Should return a PromptValue.
|
||||
Args:
|
||||
**kwargs: Keyword arguments to use for formatting.
|
||||
|
||||
Returns:
|
||||
PromptValue.
|
||||
"""
|
||||
messages = self.format_messages(**kwargs)
|
||||
return ChatPromptValue(messages=messages)
|
||||
|
||||
@ -161,11 +254,25 @@ class BaseChatPromptTemplate(BasePromptTemplate, ABC):
|
||||
|
||||
|
||||
class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
|
||||
"""Chat prompt template. This is a prompt that is sent to the user."""
|
||||
|
||||
input_variables: List[str]
|
||||
"""List of input variables."""
|
||||
messages: List[Union[BaseMessagePromptTemplate, BaseMessage]]
|
||||
"""List of messages."""
|
||||
|
||||
@root_validator(pre=True)
|
||||
def validate_input_variables(cls, values: dict) -> dict:
|
||||
"""
|
||||
Validate input variables. If input_variables is not set, it will be set to
|
||||
the union of all input variables in the messages.
|
||||
|
||||
Args:
|
||||
values: values to validate.
|
||||
|
||||
Returns:
|
||||
Validated values.
|
||||
"""
|
||||
messages = values["messages"]
|
||||
input_vars = set()
|
||||
for message in messages:
|
||||
@ -186,6 +293,15 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
|
||||
|
||||
@classmethod
|
||||
def from_template(cls, template: str, **kwargs: Any) -> ChatPromptTemplate:
|
||||
"""Create a class from a template.
|
||||
|
||||
Args:
|
||||
template: template string.
|
||||
**kwargs: keyword arguments to pass to the constructor.
|
||||
|
||||
Returns:
|
||||
A new instance of this class.
|
||||
"""
|
||||
prompt_template = PromptTemplate.from_template(template, **kwargs)
|
||||
message = HumanMessagePromptTemplate(prompt=prompt_template)
|
||||
return cls.from_messages([message])
|
||||
@ -194,6 +310,14 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
|
||||
def from_role_strings(
|
||||
cls, string_messages: List[Tuple[str, str]]
|
||||
) -> ChatPromptTemplate:
|
||||
"""Create a class from a list of (role, template) tuples.
|
||||
|
||||
Args:
|
||||
string_messages: list of (role, template) tuples.
|
||||
|
||||
Returns:
|
||||
A new instance of this class.
|
||||
"""
|
||||
messages = [
|
||||
ChatMessagePromptTemplate(
|
||||
prompt=PromptTemplate.from_template(template), role=role
|
||||
@ -206,6 +330,14 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
|
||||
def from_strings(
|
||||
cls, string_messages: List[Tuple[Type[BaseMessagePromptTemplate], str]]
|
||||
) -> ChatPromptTemplate:
|
||||
"""Create a class from a list of (role, template) tuples.
|
||||
|
||||
Args:
|
||||
string_messages: list of (role, template) tuples.
|
||||
|
||||
Returns:
|
||||
A new instance of this class.
|
||||
"""
|
||||
messages = [
|
||||
role(prompt=PromptTemplate.from_template(template))
|
||||
for role, template in string_messages
|
||||
@ -216,6 +348,15 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
|
||||
def from_messages(
|
||||
cls, messages: Sequence[Union[BaseMessagePromptTemplate, BaseMessage]]
|
||||
) -> ChatPromptTemplate:
|
||||
"""
|
||||
Create a class from a list of messages.
|
||||
|
||||
Args:
|
||||
messages: list of messages.
|
||||
|
||||
Returns:
|
||||
A new instance of this class.
|
||||
"""
|
||||
input_vars = set()
|
||||
for message in messages:
|
||||
if isinstance(message, BaseMessagePromptTemplate):
|
||||
@ -226,6 +367,15 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
|
||||
return self.format_prompt(**kwargs).to_string()
|
||||
|
||||
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
||||
"""
|
||||
Format kwargs into a list of messages.
|
||||
|
||||
Args:
|
||||
**kwargs: keyword arguments to use for formatting.
|
||||
|
||||
Returns:
|
||||
List of messages.
|
||||
"""
|
||||
kwargs = self._merge_partial_and_user_variables(**kwargs)
|
||||
result = []
|
||||
for message_template in self.messages:
|
||||
@ -251,4 +401,10 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
|
||||
return "chat"
|
||||
|
||||
def save(self, file_path: Union[Path, str]) -> None:
|
||||
"""
|
||||
Save prompt to file.
|
||||
|
||||
Args:
|
||||
file_path: path to file.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
@ -67,7 +67,7 @@ class FewShotPromptTemplate(StringPromptTemplate):
|
||||
|
||||
@root_validator()
|
||||
def template_is_valid(cls, values: Dict) -> Dict:
|
||||
"""Check that prefix, suffix and input variables are consistent."""
|
||||
"""Check that prefix, suffix, and input variables are consistent."""
|
||||
if values["validate_template"]:
|
||||
check_valid_template(
|
||||
values["prefix"] + values["suffix"],
|
||||
|
@ -59,7 +59,7 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
|
||||
|
||||
@root_validator()
|
||||
def template_is_valid(cls, values: Dict) -> Dict:
|
||||
"""Check that prefix, suffix and input variables are consistent."""
|
||||
"""Check that prefix, suffix, and input variables are consistent."""
|
||||
if values["validate_template"]:
|
||||
input_variables = values["input_variables"]
|
||||
expected_input_variables = set(values["suffix"].input_variables)
|
||||
|
@ -1,4 +1,4 @@
|
||||
"""Load prompts from disk."""
|
||||
"""Load prompts."""
|
||||
import importlib
|
||||
import json
|
||||
import logging
|
||||
@ -31,7 +31,7 @@ def load_prompt_from_config(config: dict) -> BasePromptTemplate:
|
||||
|
||||
|
||||
def _load_template(var_name: str, config: dict) -> dict:
|
||||
"""Load template from disk if applicable."""
|
||||
"""Load template from the path if applicable."""
|
||||
# Check if template_path exists in config.
|
||||
if f"{var_name}_path" in config:
|
||||
# If it does, make sure template variable doesn't also exist.
|
||||
@ -88,7 +88,7 @@ def _load_output_parser(config: dict) -> dict:
|
||||
|
||||
|
||||
def _load_few_shot_prompt(config: dict) -> FewShotPromptTemplate:
|
||||
"""Load the few shot prompt from the config."""
|
||||
"""Load the "few shot" prompt from the config."""
|
||||
# Load the suffix and prefix templates.
|
||||
config = _load_template("suffix", config)
|
||||
config = _load_template("prefix", config)
|
||||
@ -128,7 +128,7 @@ def load_prompt(path: Union[str, Path]) -> BasePromptTemplate:
|
||||
|
||||
def _load_prompt_from_file(file: Union[str, Path]) -> BasePromptTemplate:
|
||||
"""Load prompt from file."""
|
||||
# Convert file to Path object.
|
||||
# Convert file to a Path object.
|
||||
if isinstance(file, str):
|
||||
file_path = Path(file)
|
||||
else:
|
||||
|
@ -11,7 +11,7 @@ def _get_inputs(inputs: dict, input_variables: List[str]) -> dict:
|
||||
|
||||
|
||||
class PipelinePromptTemplate(BasePromptTemplate):
|
||||
"""A prompt template for composing multiple prompts together.
|
||||
"""A prompt template for composing multiple prompt templates together.
|
||||
|
||||
This can be useful when you want to reuse parts of prompts.
|
||||
A PipelinePrompt consists of two main parts:
|
||||
@ -24,7 +24,9 @@ class PipelinePromptTemplate(BasePromptTemplate):
|
||||
"""
|
||||
|
||||
final_prompt: BasePromptTemplate
|
||||
"""The final prompt that is returned."""
|
||||
pipeline_prompts: List[Tuple[str, BasePromptTemplate]]
|
||||
"""A list of tuples, consisting of a string (`name`) and a Prompt Template."""
|
||||
|
||||
@root_validator(pre=True)
|
||||
def get_input_variables(cls, values: Dict) -> Dict:
|
||||
|
Loading…
Reference in New Issue
Block a user