forked from Archives/langchain
Add BasePrompt as abstract base class (#60)
This commit is contained in:
parent
8f907161e3
commit
4bbaa9b2d0
@ -17,7 +17,7 @@ from langchain.chains import (
|
|||||||
from langchain.docstore import Wikipedia
|
from langchain.docstore import Wikipedia
|
||||||
from langchain.faiss import FAISS
|
from langchain.faiss import FAISS
|
||||||
from langchain.llms import Cohere, HuggingFaceHub, OpenAI
|
from langchain.llms import Cohere, HuggingFaceHub, OpenAI
|
||||||
from langchain.prompt import Prompt
|
from langchain.prompt import BasePrompt, Prompt
|
||||||
from langchain.sql_database import SQLDatabase
|
from langchain.sql_database import SQLDatabase
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -28,6 +28,7 @@ __all__ = [
|
|||||||
"SerpAPIChain",
|
"SerpAPIChain",
|
||||||
"Cohere",
|
"Cohere",
|
||||||
"OpenAI",
|
"OpenAI",
|
||||||
|
"BasePrompt",
|
||||||
"Prompt",
|
"Prompt",
|
||||||
"ReActChain",
|
"ReActChain",
|
||||||
"Wikipedia",
|
"Wikipedia",
|
||||||
|
@ -5,7 +5,7 @@ from pydantic import BaseModel, Extra
|
|||||||
|
|
||||||
from langchain.chains.base import Chain
|
from langchain.chains.base import Chain
|
||||||
from langchain.llms.base import LLM
|
from langchain.llms.base import LLM
|
||||||
from langchain.prompt import Prompt
|
from langchain.prompt import BasePrompt
|
||||||
|
|
||||||
|
|
||||||
class LLMChain(Chain, BaseModel):
|
class LLMChain(Chain, BaseModel):
|
||||||
@ -20,7 +20,7 @@ class LLMChain(Chain, BaseModel):
|
|||||||
llm = LLMChain(llm=OpenAI(), prompt=prompt)
|
llm = LLMChain(llm=OpenAI(), prompt=prompt)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
prompt: Prompt
|
prompt: BasePrompt
|
||||||
"""Prompt object to use."""
|
"""Prompt object to use."""
|
||||||
llm: LLM
|
llm: LLM
|
||||||
"""LLM wrapper to use."""
|
"""LLM wrapper to use."""
|
||||||
|
@ -11,7 +11,7 @@ from pydantic import BaseModel, Extra
|
|||||||
from langchain.chains.base import Chain
|
from langchain.chains.base import Chain
|
||||||
from langchain.chains.llm import LLMChain
|
from langchain.chains.llm import LLMChain
|
||||||
from langchain.llms.base import LLM
|
from langchain.llms.base import LLM
|
||||||
from langchain.prompt import Prompt
|
from langchain.prompt import BasePrompt
|
||||||
from langchain.text_splitter import TextSplitter
|
from langchain.text_splitter import TextSplitter
|
||||||
|
|
||||||
|
|
||||||
@ -29,7 +29,7 @@ class MapReduceChain(Chain, BaseModel):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_params(
|
def from_params(
|
||||||
cls, llm: LLM, prompt: Prompt, text_splitter: TextSplitter
|
cls, llm: LLM, prompt: BasePrompt, text_splitter: TextSplitter
|
||||||
) -> "MapReduceChain":
|
) -> "MapReduceChain":
|
||||||
"""Construct a map-reduce chain that uses the chain for map and reduce."""
|
"""Construct a map-reduce chain that uses the chain for map and reduce."""
|
||||||
llm_chain = LLMChain(llm=llm, prompt=prompt)
|
llm_chain = LLMChain(llm=llm, prompt=prompt)
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
"""Prompt schema definition."""
|
"""Prompt schema definition."""
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
from pydantic import BaseModel, Extra, root_validator
|
from pydantic import BaseModel, Extra, root_validator
|
||||||
@ -10,7 +11,31 @@ _FORMATTER_MAPPING = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class Prompt(BaseModel):
|
class BasePrompt(ABC):
|
||||||
|
"""Base prompt should expose the format method, returning a prompt."""
|
||||||
|
|
||||||
|
input_variables: List[str]
|
||||||
|
"""A list of the names of the variables the prompt template expects."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def format(self, **kwargs: Any) -> str:
|
||||||
|
"""Format the prompt with the inputs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
kwargs: Any arguments to be passed to the prompt template.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A formatted string.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
prompt.format(variable1="foo")
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class Prompt(BaseModel, BasePrompt):
|
||||||
"""Schema to represent a prompt for an LLM.
|
"""Schema to represent a prompt for an LLM.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
Loading…
Reference in New Issue
Block a user