diff --git a/langchain/__init__.py b/langchain/__init__.py index d7680041..d604c923 100644 --- a/langchain/__init__.py +++ b/langchain/__init__.py @@ -17,7 +17,7 @@ from langchain.chains import ( from langchain.docstore import Wikipedia from langchain.faiss import FAISS from langchain.llms import Cohere, HuggingFaceHub, OpenAI -from langchain.prompt import Prompt +from langchain.prompt import BasePrompt, Prompt from langchain.sql_database import SQLDatabase __all__ = [ @@ -28,6 +28,7 @@ __all__ = [ "SerpAPIChain", "Cohere", "OpenAI", + "BasePrompt", "Prompt", "ReActChain", "Wikipedia", diff --git a/langchain/chains/llm.py b/langchain/chains/llm.py index d1def8cb..9dadf9ef 100644 --- a/langchain/chains/llm.py +++ b/langchain/chains/llm.py @@ -5,7 +5,7 @@ from pydantic import BaseModel, Extra from langchain.chains.base import Chain from langchain.llms.base import LLM -from langchain.prompt import Prompt +from langchain.prompt import BasePrompt class LLMChain(Chain, BaseModel): @@ -20,7 +20,7 @@ class LLMChain(Chain, BaseModel): llm = LLMChain(llm=OpenAI(), prompt=prompt) """ - prompt: Prompt + prompt: BasePrompt """Prompt object to use.""" llm: LLM """LLM wrapper to use.""" diff --git a/langchain/chains/mapreduce.py b/langchain/chains/mapreduce.py index b13e6671..5ba2a1dc 100644 --- a/langchain/chains/mapreduce.py +++ b/langchain/chains/mapreduce.py @@ -11,7 +11,7 @@ from pydantic import BaseModel, Extra from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.llms.base import LLM -from langchain.prompt import Prompt +from langchain.prompt import BasePrompt from langchain.text_splitter import TextSplitter @@ -29,7 +29,7 @@ class MapReduceChain(Chain, BaseModel): @classmethod def from_params( - cls, llm: LLM, prompt: Prompt, text_splitter: TextSplitter + cls, llm: LLM, prompt: BasePrompt, text_splitter: TextSplitter ) -> "MapReduceChain": """Construct a map-reduce chain that uses the chain for map and reduce.""" llm_chain = LLMChain(llm=llm, prompt=prompt) diff --git a/langchain/prompt.py b/langchain/prompt.py index a3512ed8..bec97428 100644 --- a/langchain/prompt.py +++ b/langchain/prompt.py @@ -1,4 +1,5 @@ """Prompt schema definition.""" +from abc import ABC, abstractmethod from typing import Any, Dict, List from pydantic import BaseModel, Extra, root_validator @@ -10,7 +11,31 @@ _FORMATTER_MAPPING = { } -class Prompt(BaseModel): +class BasePrompt(ABC): + """Base prompt should expose the format method, returning a prompt.""" + + input_variables: List[str] + """A list of the names of the variables the prompt template expects.""" + + @abstractmethod + def format(self, **kwargs: Any) -> str: + """Format the prompt with the inputs. + + Args: + kwargs: Any arguments to be passed to the prompt template. + + Returns: + A formatted string. + + Example: + + .. code-block:: python + + prompt.format(variable1="foo") + """ + + +class Prompt(BaseModel, BasePrompt): """Schema to represent a prompt for an LLM. Example: