diff --git a/langchain/prompts/base.py b/langchain/prompts/base.py index c8d9c0f9..08b6ceca 100644 --- a/langchain/prompts/base.py +++ b/langchain/prompts/base.py @@ -1,6 +1,7 @@ """BasePrompt schema definition.""" from abc import ABC, abstractmethod -from typing import Any, List +from typing import Any, List, Union, Dict +from pydantic import BaseModel, Field from langchain.formatting import formatter @@ -41,11 +42,14 @@ class DefaultParser(OutputParser): """Parse the output of an LLM call.""" return text -class BasePromptTemplate(ABC): + +class BasePromptTemplate(BaseModel, ABC): """Base prompt should expose the format method, returning a prompt.""" input_variables: List[str] """A list of the names of the variables the prompt template expects.""" + output_parser: OutputParser = Field(default_factory=DefaultParser) + """How to parse the output of calling an LLM on this formatted prompt.""" @abstractmethod def format(self, **kwargs: Any) -> str: