Add BasePrompt as abstract base class (#60)

Samantha Whitmore 2022-11-04 08:42:45 -07:00 committed by GitHub
parent 8f907161e3
commit 4bbaa9b2d0
4 changed files with 32 additions and 6 deletions


@@ -17,7 +17,7 @@ from langchain.chains import (
 from langchain.docstore import Wikipedia
 from langchain.faiss import FAISS
 from langchain.llms import Cohere, HuggingFaceHub, OpenAI
-from langchain.prompt import Prompt
+from langchain.prompt import BasePrompt, Prompt
 from langchain.sql_database import SQLDatabase
 
 __all__ = [
@@ -28,6 +28,7 @@ __all__ = [
     "SerpAPIChain",
     "Cohere",
     "OpenAI",
+    "BasePrompt",
     "Prompt",
     "ReActChain",
     "Wikipedia",


@@ -5,7 +5,7 @@ from pydantic import BaseModel, Extra
 
 from langchain.chains.base import Chain
 from langchain.llms.base import LLM
-from langchain.prompt import Prompt
+from langchain.prompt import BasePrompt
 
 
 class LLMChain(Chain, BaseModel):
@@ -20,7 +20,7 @@ class LLMChain(Chain, BaseModel):
             llm = LLMChain(llm=OpenAI(), prompt=prompt)
     """
 
-    prompt: Prompt
+    prompt: BasePrompt
     """Prompt object to use."""
     llm: LLM
     """LLM wrapper to use."""


@@ -11,7 +11,7 @@ from pydantic import BaseModel, Extra
 
 from langchain.chains.base import Chain
 from langchain.chains.llm import LLMChain
 from langchain.llms.base import LLM
-from langchain.prompt import Prompt
+from langchain.prompt import BasePrompt
 from langchain.text_splitter import TextSplitter
 
@@ -29,7 +29,7 @@ class MapReduceChain(Chain, BaseModel):
 
     @classmethod
     def from_params(
-        cls, llm: LLM, prompt: Prompt, text_splitter: TextSplitter
+        cls, llm: LLM, prompt: BasePrompt, text_splitter: TextSplitter
     ) -> "MapReduceChain":
         """Construct a map-reduce chain that uses the chain for map and reduce."""
         llm_chain = LLMChain(llm=llm, prompt=prompt)


@@ -1,4 +1,5 @@
 """Prompt schema definition."""
+from abc import ABC, abstractmethod
 from typing import Any, Dict, List
 
 from pydantic import BaseModel, Extra, root_validator
@@ -10,7 +11,31 @@ _FORMATTER_MAPPING = {
 }
 
 
-class Prompt(BaseModel):
+class BasePrompt(ABC):
+    """Base prompt should expose the format method, returning a prompt."""
+
+    input_variables: List[str]
+    """A list of the names of the variables the prompt template expects."""
+
+    @abstractmethod
+    def format(self, **kwargs: Any) -> str:
+        """Format the prompt with the inputs.
+
+        Args:
+            kwargs: Any arguments to be passed to the prompt template.
+
+        Returns:
+            A formatted string.
+
+        Example:
+
+        .. code-block:: python
+
+            prompt.format(variable1="foo")
+        """
+
+
+class Prompt(BaseModel, BasePrompt):
     """Schema to represent a prompt for an LLM.
 
     Example:
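
As a rough sketch of what this change enables (nothing below is part of the commit: StaticPrompt is a made-up subclass, and an OpenAI API key is assumed to be available in the environment), any object that subclasses BasePrompt and implements format() can now be supplied wherever a concrete Prompt was previously required, e.g. as the prompt of an LLMChain or MapReduceChain.from_params:

    from typing import Any, List

    from langchain.chains import LLMChain
    from langchain.llms import OpenAI
    from langchain.prompt import BasePrompt


    class StaticPrompt(BasePrompt):
        """Toy prompt that ignores template variables and returns fixed text."""

        input_variables: List[str] = []

        def format(self, **kwargs: Any) -> str:
            # Nothing to substitute; always hand the LLM the same string.
            return "Tell me a joke."


    prompt = StaticPrompt()
    print(prompt.format())  # -> "Tell me a joke."

    # LLMChain.prompt is now typed as BasePrompt, so any conforming subclass
    # is accepted, not just the concrete Prompt class.
    chain = LLMChain(llm=OpenAI(), prompt=prompt)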