From 4a05b7f772af727baed89e33eb1196cb5d45af58 Mon Sep 17 00:00:00 2001 From: Leonid Ganeline Date: Tue, 18 Jul 2023 07:58:22 -0700 Subject: [PATCH] docstrings `prompts` (#7844) Added missed docstrings in `prompts` @baskaryan --- langchain/prompts/base.py | 3 + langchain/prompts/chat.py | 166 ++++++++++++++++++- langchain/prompts/few_shot.py | 2 +- langchain/prompts/few_shot_with_templates.py | 2 +- langchain/prompts/loading.py | 8 +- langchain/prompts/pipeline.py | 4 +- 6 files changed, 173 insertions(+), 12 deletions(-) diff --git a/langchain/prompts/base.py b/langchain/prompts/base.py index e527cd8276..0e39316488 100644 --- a/langchain/prompts/base.py +++ b/langchain/prompts/base.py @@ -95,7 +95,10 @@ def check_valid_template( class StringPromptValue(PromptValue): + """String prompt value.""" + text: str + """Prompt text.""" def to_string(self) -> str: """Return prompt as string.""" diff --git a/langchain/prompts/chat.py b/langchain/prompts/chat.py index 0264b8596b..4e92d8d43d 100644 --- a/langchain/prompts/chat.py +++ b/langchain/prompts/chat.py @@ -25,27 +25,53 @@ from langchain.schema.messages import ( class BaseMessagePromptTemplate(Serializable, ABC): + """Base class for message prompt templates.""" + @property def lc_serializable(self) -> bool: + """Whether this object should be serialized. + + Returns: + Whether this object should be serialized. + """ return True @abstractmethod def format_messages(self, **kwargs: Any) -> List[BaseMessage]: - """To messages.""" + """Format messages from kwargs. Should return a list of BaseMessages. + + Args: + **kwargs: Keyword arguments to use for formatting. + + Returns: + List of BaseMessages. + """ @property @abstractmethod def input_variables(self) -> List[str]: - """Input variables for this prompt template.""" + """Input variables for this prompt template. + + Returns: + List of input variables. + """ class MessagesPlaceholder(BaseMessagePromptTemplate): """Prompt template that assumes variable is already list of messages.""" variable_name: str + """Name of variable to use as messages.""" def format_messages(self, **kwargs: Any) -> List[BaseMessage]: - """To a BaseMessage.""" + """To a BaseMessage. + + Args: + **kwargs: Keyword arguments to use for formatting. + + Returns: + List of BaseMessage. + """ value = kwargs[self.variable_name] if not isinstance(value, list): raise ValueError( @@ -62,18 +88,27 @@ class MessagesPlaceholder(BaseMessagePromptTemplate): @property def input_variables(self) -> List[str]: - """Input variables for this prompt template.""" + """Input variables for this prompt template. + + Returns: + List of input variable names. + """ return [self.variable_name] MessagePromptTemplateT = TypeVar( "MessagePromptTemplateT", bound="BaseStringMessagePromptTemplate" ) +"""Type variable for message prompt templates.""" class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC): + """Base class for message prompt templates that use a string prompt template.""" + prompt: StringPromptTemplate + """String prompt template.""" additional_kwargs: dict = Field(default_factory=dict) + """Additional keyword arguments to pass to the prompt template.""" @classmethod def from_template( @@ -82,6 +117,16 @@ class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC): template_format: str = "f-string", **kwargs: Any, ) -> MessagePromptTemplateT: + """Create a class from a string template. + + Args: + template: a template. + template_format: format of the template. + **kwargs: keyword arguments to pass to the constructor. + + Returns: + A new instance of this class. + """ prompt = PromptTemplate.from_template(template, template_format=template_format) return cls(prompt=prompt, **kwargs) @@ -92,6 +137,16 @@ class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC): input_variables: List[str], **kwargs: Any, ) -> MessagePromptTemplateT: + """Create a class from a template file. + + Args: + template_file: path to a template file. String or Path. + input_variables: list of input variables. + **kwargs: keyword arguments to pass to the constructor. + + Returns: + A new instance of this class. + """ prompt = PromptTemplate.from_file(template_file, input_variables) return cls(prompt=prompt, **kwargs) @@ -100,15 +155,32 @@ class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC): """To a BaseMessage.""" def format_messages(self, **kwargs: Any) -> List[BaseMessage]: + """Format messages from kwargs. Should return a list of BaseMessages. + + Args: + **kwargs: Keyword arguments to use for formatting. + + Returns: + List of BaseMessages. + """ return [self.format(**kwargs)] @property def input_variables(self) -> List[str]: + """ + Input variables for this prompt template. + + Returns: + List of input variable names. + """ return self.prompt.input_variables class ChatMessagePromptTemplate(BaseStringMessagePromptTemplate): + """Chat message prompt template.""" + role: str + """Role of the message.""" def format(self, **kwargs: Any) -> BaseMessage: text = self.prompt.format(**kwargs) @@ -118,40 +190,61 @@ class ChatMessagePromptTemplate(BaseStringMessagePromptTemplate): class HumanMessagePromptTemplate(BaseStringMessagePromptTemplate): + """Human message prompt template. This is a message that is sent to the user.""" + def format(self, **kwargs: Any) -> BaseMessage: text = self.prompt.format(**kwargs) return HumanMessage(content=text, additional_kwargs=self.additional_kwargs) class AIMessagePromptTemplate(BaseStringMessagePromptTemplate): + """AI message prompt template. This is a message that is not sent to the user.""" + def format(self, **kwargs: Any) -> BaseMessage: text = self.prompt.format(**kwargs) return AIMessage(content=text, additional_kwargs=self.additional_kwargs) class SystemMessagePromptTemplate(BaseStringMessagePromptTemplate): + """System message prompt template. + This is a message that is not sent to the user. + """ + def format(self, **kwargs: Any) -> BaseMessage: text = self.prompt.format(**kwargs) return SystemMessage(content=text, additional_kwargs=self.additional_kwargs) class ChatPromptValue(PromptValue): + """Chat prompt value.""" + messages: List[BaseMessage] + """List of messages.""" def to_string(self) -> str: """Return prompt as string.""" return get_buffer_string(self.messages) def to_messages(self) -> List[BaseMessage]: - """Return prompt as messages.""" + """Return prompt as a list of messages.""" return self.messages class BaseChatPromptTemplate(BasePromptTemplate, ABC): + """Base class for chat prompt templates.""" + def format(self, **kwargs: Any) -> str: return self.format_prompt(**kwargs).to_string() def format_prompt(self, **kwargs: Any) -> PromptValue: + """ + Format prompt. Should return a PromptValue. + Args: + **kwargs: Keyword arguments to use for formatting. + + Returns: + PromptValue. + """ messages = self.format_messages(**kwargs) return ChatPromptValue(messages=messages) @@ -161,11 +254,25 @@ class BaseChatPromptTemplate(BasePromptTemplate, ABC): class ChatPromptTemplate(BaseChatPromptTemplate, ABC): + """Chat prompt template. This is a prompt that is sent to the user.""" + input_variables: List[str] + """List of input variables.""" messages: List[Union[BaseMessagePromptTemplate, BaseMessage]] + """List of messages.""" @root_validator(pre=True) def validate_input_variables(cls, values: dict) -> dict: + """ + Validate input variables. If input_variables is not set, it will be set to + the union of all input variables in the messages. + + Args: + values: values to validate. + + Returns: + Validated values. + """ messages = values["messages"] input_vars = set() for message in messages: @@ -186,6 +293,15 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC): @classmethod def from_template(cls, template: str, **kwargs: Any) -> ChatPromptTemplate: + """Create a class from a template. + + Args: + template: template string. + **kwargs: keyword arguments to pass to the constructor. + + Returns: + A new instance of this class. + """ prompt_template = PromptTemplate.from_template(template, **kwargs) message = HumanMessagePromptTemplate(prompt=prompt_template) return cls.from_messages([message]) @@ -194,6 +310,14 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC): def from_role_strings( cls, string_messages: List[Tuple[str, str]] ) -> ChatPromptTemplate: + """Create a class from a list of (role, template) tuples. + + Args: + string_messages: list of (role, template) tuples. + + Returns: + A new instance of this class. + """ messages = [ ChatMessagePromptTemplate( prompt=PromptTemplate.from_template(template), role=role @@ -206,6 +330,14 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC): def from_strings( cls, string_messages: List[Tuple[Type[BaseMessagePromptTemplate], str]] ) -> ChatPromptTemplate: + """Create a class from a list of (role, template) tuples. + + Args: + string_messages: list of (role, template) tuples. + + Returns: + A new instance of this class. + """ messages = [ role(prompt=PromptTemplate.from_template(template)) for role, template in string_messages @@ -216,6 +348,15 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC): def from_messages( cls, messages: Sequence[Union[BaseMessagePromptTemplate, BaseMessage]] ) -> ChatPromptTemplate: + """ + Create a class from a list of messages. + + Args: + messages: list of messages. + + Returns: + A new instance of this class. + """ input_vars = set() for message in messages: if isinstance(message, BaseMessagePromptTemplate): @@ -226,6 +367,15 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC): return self.format_prompt(**kwargs).to_string() def format_messages(self, **kwargs: Any) -> List[BaseMessage]: + """ + Format kwargs into a list of messages. + + Args: + **kwargs: keyword arguments to use for formatting. + + Returns: + List of messages. + """ kwargs = self._merge_partial_and_user_variables(**kwargs) result = [] for message_template in self.messages: @@ -251,4 +401,10 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC): return "chat" def save(self, file_path: Union[Path, str]) -> None: + """ + Save prompt to file. + + Args: + file_path: path to file. + """ raise NotImplementedError diff --git a/langchain/prompts/few_shot.py b/langchain/prompts/few_shot.py index 9012295385..d89de95c26 100644 --- a/langchain/prompts/few_shot.py +++ b/langchain/prompts/few_shot.py @@ -67,7 +67,7 @@ class FewShotPromptTemplate(StringPromptTemplate): @root_validator() def template_is_valid(cls, values: Dict) -> Dict: - """Check that prefix, suffix and input variables are consistent.""" + """Check that prefix, suffix, and input variables are consistent.""" if values["validate_template"]: check_valid_template( values["prefix"] + values["suffix"], diff --git a/langchain/prompts/few_shot_with_templates.py b/langchain/prompts/few_shot_with_templates.py index c305f17182..5e5330cf94 100644 --- a/langchain/prompts/few_shot_with_templates.py +++ b/langchain/prompts/few_shot_with_templates.py @@ -59,7 +59,7 @@ class FewShotPromptWithTemplates(StringPromptTemplate): @root_validator() def template_is_valid(cls, values: Dict) -> Dict: - """Check that prefix, suffix and input variables are consistent.""" + """Check that prefix, suffix, and input variables are consistent.""" if values["validate_template"]: input_variables = values["input_variables"] expected_input_variables = set(values["suffix"].input_variables) diff --git a/langchain/prompts/loading.py b/langchain/prompts/loading.py index cc7507cc68..86937b65dc 100644 --- a/langchain/prompts/loading.py +++ b/langchain/prompts/loading.py @@ -1,4 +1,4 @@ -"""Load prompts from disk.""" +"""Load prompts.""" import importlib import json import logging @@ -31,7 +31,7 @@ def load_prompt_from_config(config: dict) -> BasePromptTemplate: def _load_template(var_name: str, config: dict) -> dict: - """Load template from disk if applicable.""" + """Load template from the path if applicable.""" # Check if template_path exists in config. if f"{var_name}_path" in config: # If it does, make sure template variable doesn't also exist. @@ -88,7 +88,7 @@ def _load_output_parser(config: dict) -> dict: def _load_few_shot_prompt(config: dict) -> FewShotPromptTemplate: - """Load the few shot prompt from the config.""" + """Load the "few shot" prompt from the config.""" # Load the suffix and prefix templates. config = _load_template("suffix", config) config = _load_template("prefix", config) @@ -128,7 +128,7 @@ def load_prompt(path: Union[str, Path]) -> BasePromptTemplate: def _load_prompt_from_file(file: Union[str, Path]) -> BasePromptTemplate: """Load prompt from file.""" - # Convert file to Path object. + # Convert file to a Path object. if isinstance(file, str): file_path = Path(file) else: diff --git a/langchain/prompts/pipeline.py b/langchain/prompts/pipeline.py index 41e34b17ed..28364766a7 100644 --- a/langchain/prompts/pipeline.py +++ b/langchain/prompts/pipeline.py @@ -11,7 +11,7 @@ def _get_inputs(inputs: dict, input_variables: List[str]) -> dict: class PipelinePromptTemplate(BasePromptTemplate): - """A prompt template for composing multiple prompts together. + """A prompt template for composing multiple prompt templates together. This can be useful when you want to reuse parts of prompts. A PipelinePrompt consists of two main parts: @@ -24,7 +24,9 @@ class PipelinePromptTemplate(BasePromptTemplate): """ final_prompt: BasePromptTemplate + """The final prompt that is returned.""" pipeline_prompts: List[Tuple[str, BasePromptTemplate]] + """A list of tuples, consisting of a string (`name`) and a Prompt Template.""" @root_validator(pre=True) def get_input_variables(cls, values: Dict) -> Dict: