Add BasePrompt as abstract base class (#60)

This commit is contained in:
Samantha Whitmore 2022-11-04 08:42:45 -07:00 committed by GitHub
parent 8f907161e3
commit 4bbaa9b2d0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 32 additions and 6 deletions

View File

@ -17,7 +17,7 @@ from langchain.chains import (
from langchain.docstore import Wikipedia
from langchain.faiss import FAISS
from langchain.llms import Cohere, HuggingFaceHub, OpenAI
from langchain.prompt import Prompt
from langchain.prompt import BasePrompt, Prompt
from langchain.sql_database import SQLDatabase
__all__ = [
@ -28,6 +28,7 @@ __all__ = [
"SerpAPIChain",
"Cohere",
"OpenAI",
"BasePrompt",
"Prompt",
"ReActChain",
"Wikipedia",

View File

@ -5,7 +5,7 @@ from pydantic import BaseModel, Extra
from langchain.chains.base import Chain
from langchain.llms.base import LLM
from langchain.prompt import Prompt
from langchain.prompt import BasePrompt
class LLMChain(Chain, BaseModel):
@ -20,7 +20,7 @@ class LLMChain(Chain, BaseModel):
llm = LLMChain(llm=OpenAI(), prompt=prompt)
"""
prompt: Prompt
prompt: BasePrompt
"""Prompt object to use."""
llm: LLM
"""LLM wrapper to use."""

View File

@ -11,7 +11,7 @@ from pydantic import BaseModel, Extra
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.llms.base import LLM
from langchain.prompt import Prompt
from langchain.prompt import BasePrompt
from langchain.text_splitter import TextSplitter
@ -29,7 +29,7 @@ class MapReduceChain(Chain, BaseModel):
@classmethod
def from_params(
cls, llm: LLM, prompt: Prompt, text_splitter: TextSplitter
cls, llm: LLM, prompt: BasePrompt, text_splitter: TextSplitter
) -> "MapReduceChain":
"""Construct a map-reduce chain that uses the chain for map and reduce."""
llm_chain = LLMChain(llm=llm, prompt=prompt)

View File

@ -1,4 +1,5 @@
"""Prompt schema definition."""
from abc import ABC, abstractmethod
from typing import Any, Dict, List
from pydantic import BaseModel, Extra, root_validator
@ -10,7 +11,31 @@ _FORMATTER_MAPPING = {
}
class Prompt(BaseModel):
class BasePrompt(ABC):
"""Base prompt should expose the format method, returning a prompt."""
input_variables: List[str]
"""A list of the names of the variables the prompt template expects."""
@abstractmethod
def format(self, **kwargs: Any) -> str:
"""Format the prompt with the inputs.
Args:
kwargs: Any arguments to be passed to the prompt template.
Returns:
A formatted string.
Example:
.. code-block:: python
prompt.format(variable1="foo")
"""
class Prompt(BaseModel, BasePrompt):
"""Schema to represent a prompt for an LLM.
Example: