Harrison/improve cache (#368)

Make everything go through `generate`, which removes the need for two types of caches (a string-only cache and a full `List[Generation]` cache).
Harrison Chase committed 1 year ago (via GitHub)
parent 8d0869c6d3
commit 3474f39e21
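To make the change concrete, here is a rough sketch (not the literal library code) of what "everything goes through generate" means after this commit: `__call__` becomes a thin wrapper, and `generate` is the single code path that consults `langchain.llm_cache`, so only the `List[Generation]`-valued cache is needed.

    # Sketch, assuming a BaseLLM instance named `llm` (e.g. OpenAI()).
    text = llm("Tell me a joke")               # __call__ ...
    # ...is now just shorthand for:
    result = llm.generate(["Tell me a joke"])  # the one cache-aware entry point
    text = result.generations[0][0].text       # first prompt, first Generation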

@@ -11,7 +11,7 @@ from langchain.agents.tools import Tool
 from langchain.chains.base import Chain
 from langchain.chains.llm import LLMChain
 from langchain.input import get_color_mapping
-from langchain.llms.base import LLM
+from langchain.llms.base import BaseLLM
 from langchain.prompts.base import BasePromptTemplate
 from langchain.schema import AgentAction
@@ -87,7 +87,9 @@ class Agent(Chain, BaseModel, ABC):
         pass
 
     @classmethod
-    def from_llm_and_tools(cls, llm: LLM, tools: List[Tool], **kwargs: Any) -> Agent:
+    def from_llm_and_tools(
+        cls, llm: BaseLLM, tools: List[Tool], **kwargs: Any
+    ) -> Agent:
         """Construct an agent from an LLM and tools."""
         cls._validate_tools(tools)
         llm_chain = LLMChain(llm=llm, prompt=cls.create_prompt(tools))

@@ -6,7 +6,7 @@ from langchain.agents.mrkl.base import ZeroShotAgent
 from langchain.agents.react.base import ReActDocstoreAgent
 from langchain.agents.self_ask_with_search.base import SelfAskWithSearchAgent
 from langchain.agents.tools import Tool
-from langchain.llms.base import LLM
+from langchain.llms.base import BaseLLM
 
 AGENT_TO_CLASS = {
     "zero-shot-react-description": ZeroShotAgent,
@@ -17,7 +17,7 @@ AGENT_TO_CLASS = {
 def initialize_agent(
     tools: List[Tool],
-    llm: LLM,
+    llm: BaseLLM,
     agent: str = "zero-shot-react-description",
     **kwargs: Any,
 ) -> Agent:

@@ -6,7 +6,7 @@ from typing import Any, Callable, List, NamedTuple, Optional, Tuple
 from langchain.agents.agent import Agent
 from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX
 from langchain.agents.tools import Tool
-from langchain.llms.base import LLM
+from langchain.llms.base import BaseLLM
 from langchain.prompts import PromptTemplate
 
 FINAL_ANSWER_ACTION = "Final Answer: "
@@ -116,7 +116,9 @@ class MRKLChain(ZeroShotAgent):
     """
 
     @classmethod
-    def from_chains(cls, llm: LLM, chains: List[ChainConfig], **kwargs: Any) -> Agent:
+    def from_chains(
+        cls, llm: BaseLLM, chains: List[ChainConfig], **kwargs: Any
+    ) -> Agent:
         """User friendly way to initialize the MRKL chain.
 
         This is intended to be an easy way to get up and running with the

@@ -11,7 +11,7 @@ from langchain.agents.tools import Tool
 from langchain.chains.llm import LLMChain
 from langchain.docstore.base import Docstore
 from langchain.docstore.document import Document
-from langchain.llms.base import LLM
+from langchain.llms.base import BaseLLM
 from langchain.prompts.base import BasePromptTemplate
@@ -123,7 +123,7 @@ class ReActChain(ReActDocstoreAgent):
         react = ReAct(llm=OpenAI())
     """
 
-    def __init__(self, llm: LLM, docstore: Docstore, **kwargs: Any):
+    def __init__(self, llm: BaseLLM, docstore: Docstore, **kwargs: Any):
         """Initialize with the LLM and a docstore."""
         docstore_explorer = DocstoreExplorer(docstore)
         tools = [

@@ -5,7 +5,7 @@ from langchain.agents.agent import Agent
 from langchain.agents.self_ask_with_search.prompt import PROMPT
 from langchain.agents.tools import Tool
 from langchain.chains.llm import LLMChain
-from langchain.llms.base import LLM
+from langchain.llms.base import BaseLLM
 from langchain.prompts.base import BasePromptTemplate
 from langchain.serpapi import SerpAPIWrapper
@@ -72,7 +72,7 @@ class SelfAskWithSearchChain(SelfAskWithSearchAgent):
         self_ask = SelfAskWithSearchChain(llm=OpenAI(), search_chain=search_chain)
     """
 
-    def __init__(self, llm: LLM, search_chain: SerpAPIWrapper, **kwargs: Any):
+    def __init__(self, llm: BaseLLM, search_chain: SerpAPIWrapper, **kwargs: Any):
         """Initialize with just an LLM and a search chain."""
         search_tool = Tool(name="Intermediate Answer", func=search_chain.run)
         llm_chain = LLMChain(llm=llm, prompt=PROMPT)

@@ -1,6 +1,6 @@
 """Beta Feature: base interface for cache."""
 from abc import ABC, abstractmethod
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple
 
 from sqlalchemy import Column, Integer, String, create_engine, select
 from sqlalchemy.engine.base import Engine
@@ -9,7 +9,7 @@ from sqlalchemy.orm import Session
 
 from langchain.schema import Generation
 
-RETURN_VAL_TYPE = Union[List[Generation], str]
+RETURN_VAL_TYPE = List[Generation]
 
 class BaseCache(ABC):
@@ -43,15 +43,6 @@ class InMemoryCache(BaseCache):
 Base = declarative_base()
 
-class LLMCache(Base):  # type: ignore
-    """SQLite table for simple LLM cache (string only)."""
-
-    __tablename__ = "llm_cache"
-    prompt = Column(String, primary_key=True)
-    llm = Column(String, primary_key=True)
-    response = Column(String)
-
 class FullLLMCache(Base):  # type: ignore
     """SQLite table for full LLM Cache (all generations)."""
@@ -84,29 +75,16 @@ class SQLAlchemyCache(BaseCache):
                 generations.append(Generation(text=row[0]))
             if len(generations) > 0:
                 return generations
-        stmt = (
-            select(LLMCache.response)
-            .where(LLMCache.prompt == prompt)
-            .where(LLMCache.llm == llm_string)
-        )
-        with Session(self.engine) as session:
-            for row in session.execute(stmt):
-                return row[0]
         return None
 
     def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
         """Look up based on prompt and llm_string."""
-        if isinstance(return_val, str):
-            item = LLMCache(prompt=prompt, llm=llm_string, response=return_val)
+        for i, generation in enumerate(return_val):
+            item = FullLLMCache(
+                prompt=prompt, llm=llm_string, response=generation.text, idx=i
+            )
             with Session(self.engine) as session, session.begin():
                 session.add(item)
-        else:
-            for i, generation in enumerate(return_val):
-                item = FullLLMCache(
-                    prompt=prompt, llm=llm_string, response=generation.text, idx=i
-                )
-                with Session(self.engine) as session, session.begin():
-                    session.add(item)
 
 class SQLiteCache(SQLAlchemyCache):
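For orientation, a minimal usage sketch of the now-unified cache. The `SQLiteCache` constructor argument is assumed from the langchain of this era, so treat the details as illustrative:

    import langchain
    from langchain.cache import SQLiteCache
    from langchain.llms import OpenAI

    langchain.llm_cache = SQLiteCache(database_path=".langchain.db")
    llm = OpenAI(temperature=0)

    llm("Tell me a joke")  # first call hits the API and writes one
                           # FullLLMCache row per Generation (keyed by idx)
    llm("Tell me a joke")  # second call is served from the cache via generate()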

@@ -9,7 +9,7 @@ from langchain.chains.api.prompt import API_RESPONSE_PROMPT, API_URL_PROMPT
 from langchain.chains.base import Chain
 from langchain.chains.llm import LLMChain
 from langchain.input import print_text
-from langchain.llms.base import LLM
+from langchain.llms.base import BaseLLM
 from langchain.requests import RequestsWrapper
@@ -81,7 +81,7 @@ class APIChain(Chain, BaseModel):
     @classmethod
     def from_llm_and_api_docs(
-        cls, llm: LLM, api_docs: str, headers: Optional[dict] = None, **kwargs: Any
+        cls, llm: BaseLLM, api_docs: str, headers: Optional[dict] = None, **kwargs: Any
     ) -> APIChain:
         """Load chain from just an LLM and the api docs."""
         get_request_chain = LLMChain(llm=llm, prompt=API_URL_PROMPT)

@@ -6,7 +6,7 @@ from pydantic import BaseModel, Field, root_validator
 from langchain.chains.base import Memory
 from langchain.chains.conversation.prompt import SUMMARY_PROMPT
 from langchain.chains.llm import LLMChain
-from langchain.llms.base import LLM
+from langchain.llms.base import BaseLLM
 from langchain.prompts.base import BasePromptTemplate
@@ -88,7 +88,7 @@ class ConversationSummaryMemory(Memory, BaseModel):
     """Conversation summarizer to memory."""
 
     buffer: str = ""
-    llm: LLM
+    llm: BaseLLM
     prompt: BasePromptTemplate = SUMMARY_PROMPT
     memory_key: str = "history"  #: :meta private:

@@ -5,7 +5,7 @@ from pydantic import BaseModel, Extra
 import langchain
 from langchain.chains.base import Chain
-from langchain.llms.base import LLM
+from langchain.llms.base import BaseLLM
 from langchain.prompts.base import BasePromptTemplate
@@ -25,7 +25,7 @@ class LLMChain(Chain, BaseModel):
     prompt: BasePromptTemplate
     """Prompt object to use."""
-    llm: LLM
+    llm: BaseLLM
     """LLM wrapper to use."""
     output_key: str = "text"  #: :meta private:

@@ -7,7 +7,7 @@ from langchain.chains.base import Chain
 from langchain.chains.llm import LLMChain
 from langchain.chains.llm_bash.prompt import PROMPT
 from langchain.input import print_text
-from langchain.llms.base import LLM
+from langchain.llms.base import BaseLLM
 from langchain.utilities.bash import BashProcess
@@ -21,7 +21,7 @@ class LLMBashChain(Chain, BaseModel):
         llm_bash = LLMBashChain(llm=OpenAI())
     """
 
-    llm: LLM
+    llm: BaseLLM
     """LLM wrapper to use."""
     input_key: str = "question"  #: :meta private:
     output_key: str = "answer"  #: :meta private:

@@ -14,7 +14,7 @@ from langchain.chains.llm_checker.prompt import (
     REVISED_ANSWER_PROMPT,
 )
 from langchain.chains.sequential import SequentialChain
-from langchain.llms.base import LLM
+from langchain.llms.base import BaseLLM
 from langchain.prompts import PromptTemplate
@@ -28,7 +28,7 @@ class LLMCheckerChain(Chain, BaseModel):
         checker_chain = LLMCheckerChain(llm=llm)
     """
 
-    llm: LLM
+    llm: BaseLLM
     """LLM wrapper to use."""
     create_draft_answer_prompt: PromptTemplate = CREATE_DRAFT_ANSWER_PROMPT
     list_assertions_prompt: PromptTemplate = LIST_ASSERTIONS_PROMPT

@@ -7,7 +7,7 @@ from langchain.chains.base import Chain
 from langchain.chains.llm import LLMChain
 from langchain.chains.llm_math.prompt import PROMPT
 from langchain.input import print_text
-from langchain.llms.base import LLM
+from langchain.llms.base import BaseLLM
 from langchain.python import PythonREPL
@@ -21,7 +21,7 @@ class LLMMathChain(Chain, BaseModel):
         llm_math = LLMMathChain(llm=OpenAI())
     """
 
-    llm: LLM
+    llm: BaseLLM
    """LLM wrapper to use."""
     input_key: str = "question"  #: :meta private:
     output_key: str = "answer"  #: :meta private:

@@ -15,7 +15,7 @@ from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChai
 from langchain.chains.combine_documents.stuff import StuffDocumentsChain
 from langchain.chains.llm import LLMChain
 from langchain.docstore.document import Document
-from langchain.llms.base import LLM
+from langchain.llms.base import BaseLLM
 from langchain.prompts.base import BasePromptTemplate
 from langchain.text_splitter import TextSplitter
@@ -32,7 +32,7 @@ class MapReduceChain(Chain, BaseModel):
     @classmethod
     def from_params(
-        cls, llm: LLM, prompt: BasePromptTemplate, text_splitter: TextSplitter
+        cls, llm: BaseLLM, prompt: BasePromptTemplate, text_splitter: TextSplitter
     ) -> MapReduceChain:
         """Construct a map-reduce chain that uses the chain for map and reduce."""
         llm_chain = LLMChain(llm=llm, prompt=prompt)

@@ -8,7 +8,7 @@ from pydantic import BaseModel, Extra
 from langchain.chains.base import Chain
 from langchain.chains.llm import LLMChain
 from langchain.chains.natbot.prompt import PROMPT
-from langchain.llms.base import LLM
+from langchain.llms.base import BaseLLM
 from langchain.llms.openai import OpenAI
@@ -22,7 +22,7 @@ class NatBotChain(Chain, BaseModel):
         natbot = NatBotChain(llm=OpenAI(), objective="Buy me a new hat.")
     """
 
-    llm: LLM
+    llm: BaseLLM
     """LLM wrapper to use."""
     objective: str
     """Objective that NatBot is tasked with completing."""

@@ -13,7 +13,7 @@ from langchain.chains.llm import LLMChain
 from langchain.chains.pal.colored_object_prompt import COLORED_OBJECT_PROMPT
 from langchain.chains.pal.math_prompt import MATH_PROMPT
 from langchain.input import print_text
-from langchain.llms.base import LLM
+from langchain.llms.base import BaseLLM
 from langchain.prompts.base import BasePromptTemplate
 from langchain.python import PythonREPL
@@ -21,7 +21,7 @@ from langchain.python import PythonREPL
 class PALChain(Chain, BaseModel):
     """Implements Program-Aided Language Models."""
 
-    llm: LLM
+    llm: BaseLLM
     prompt: BasePromptTemplate
     stop: str = "\n\n"
     get_answer_expr: str = "print(solution())"
@@ -59,7 +59,7 @@ class PALChain(Chain, BaseModel):
         return {self.output_key: res.strip()}
 
     @classmethod
-    def from_math_prompt(cls, llm: LLM, **kwargs: Any) -> PALChain:
+    def from_math_prompt(cls, llm: BaseLLM, **kwargs: Any) -> PALChain:
         """Load PAL from math prompt."""
         return cls(
             llm=llm,
@@ -70,7 +70,7 @@ class PALChain(Chain, BaseModel):
         )
 
     @classmethod
-    def from_colored_object_prompt(cls, llm: LLM, **kwargs: Any) -> PALChain:
+    def from_colored_object_prompt(cls, llm: BaseLLM, **kwargs: Any) -> PALChain:
         """Load PAL from colored object prompt."""
         return cls(
             llm=llm,

@@ -11,19 +11,19 @@ from langchain.chains.qa_with_sources import (
     refine_prompts,
     stuff_prompt,
 )
-from langchain.llms.base import LLM
+from langchain.llms.base import BaseLLM
 from langchain.prompts.base import BasePromptTemplate
 
 class LoadingCallable(Protocol):
     """Interface for loading the combine documents chain."""
 
-    def __call__(self, llm: LLM, **kwargs: Any) -> BaseCombineDocumentsChain:
+    def __call__(self, llm: BaseLLM, **kwargs: Any) -> BaseCombineDocumentsChain:
         """Callable to load the combine documents chain."""
 
 def _load_stuff_chain(
-    llm: LLM,
+    llm: BaseLLM,
     prompt: BasePromptTemplate = stuff_prompt.PROMPT,
     document_variable_name: str = "summaries",
     **kwargs: Any,
@@ -38,7 +38,7 @@ def _load_stuff_chain(
 def _load_map_reduce_chain(
-    llm: LLM,
+    llm: BaseLLM,
     question_prompt: BasePromptTemplate = map_reduce_prompt.QUESTION_PROMPT,
     combine_prompt: BasePromptTemplate = map_reduce_prompt.COMBINE_PROMPT,
     document_prompt: BasePromptTemplate = map_reduce_prompt.EXAMPLE_PROMPT,
@@ -72,7 +72,7 @@ def _load_map_reduce_chain(
 def _load_refine_chain(
-    llm: LLM,
+    llm: BaseLLM,
     question_prompt: BasePromptTemplate = refine_prompts.DEFAULT_TEXT_QA_PROMPT,
     refine_prompt: BasePromptTemplate = refine_prompts.DEFAULT_REFINE_PROMPT,
     document_prompt: BasePromptTemplate = refine_prompts.EXAMPLE_PROMPT,
@@ -93,7 +93,7 @@ def _load_refine_chain(
 def load_qa_with_sources_chain(
-    llm: LLM, chain_type: str = "stuff", **kwargs: Any
+    llm: BaseLLM, chain_type: str = "stuff", **kwargs: Any
 ) -> BaseCombineDocumentsChain:
     """Load question answering with sources chain.

@@ -18,7 +18,7 @@ from langchain.chains.qa_with_sources.map_reduce_prompt import (
     QUESTION_PROMPT,
 )
 from langchain.docstore.document import Document
-from langchain.llms.base import LLM
+from langchain.llms.base import BaseLLM
 from langchain.prompts.base import BasePromptTemplate
@@ -35,7 +35,7 @@ class BaseQAWithSourcesChain(Chain, BaseModel, ABC):
     @classmethod
     def from_llm(
         cls,
-        llm: LLM,
+        llm: BaseLLM,
         document_prompt: BasePromptTemplate = EXAMPLE_PROMPT,
         question_prompt: BasePromptTemplate = QUESTION_PROMPT,
         combine_prompt: BasePromptTemplate = COMBINE_PROMPT,

@@ -11,19 +11,19 @@ from langchain.chains.question_answering import (
     refine_prompts,
     stuff_prompt,
 )
-from langchain.llms.base import LLM
+from langchain.llms.base import BaseLLM
 from langchain.prompts.base import BasePromptTemplate
 
 class LoadingCallable(Protocol):
     """Interface for loading the combine documents chain."""
 
-    def __call__(self, llm: LLM, **kwargs: Any) -> BaseCombineDocumentsChain:
+    def __call__(self, llm: BaseLLM, **kwargs: Any) -> BaseCombineDocumentsChain:
         """Callable to load the combine documents chain."""
 
 def _load_stuff_chain(
-    llm: LLM,
+    llm: BaseLLM,
     prompt: BasePromptTemplate = stuff_prompt.PROMPT,
     document_variable_name: str = "context",
     **kwargs: Any,
@@ -36,7 +36,7 @@ def _load_stuff_chain(
 def _load_map_reduce_chain(
-    llm: LLM,
+    llm: BaseLLM,
     question_prompt: BasePromptTemplate = map_reduce_prompt.QUESTION_PROMPT,
     combine_prompt: BasePromptTemplate = map_reduce_prompt.COMBINE_PROMPT,
     combine_document_variable_name: str = "summaries",
@@ -67,7 +67,7 @@ def _load_map_reduce_chain(
 def _load_refine_chain(
-    llm: LLM,
+    llm: BaseLLM,
     question_prompt: BasePromptTemplate = refine_prompts.DEFAULT_TEXT_QA_PROMPT,
     refine_prompt: BasePromptTemplate = refine_prompts.DEFAULT_REFINE_PROMPT,
     document_variable_name: str = "context_str",
@@ -86,7 +86,7 @@ def _load_refine_chain(
 def load_qa_chain(
-    llm: LLM, chain_type: str = "stuff", **kwargs: Any
+    llm: BaseLLM, chain_type: str = "stuff", **kwargs: Any
 ) -> BaseCombineDocumentsChain:
     """Load question answering chain.

@@ -7,7 +7,7 @@ from langchain.chains.base import Chain
 from langchain.chains.llm import LLMChain
 from langchain.chains.sql_database.prompt import PROMPT
 from langchain.input import print_text
-from langchain.llms.base import LLM
+from langchain.llms.base import BaseLLM
 from langchain.sql_database import SQLDatabase
@@ -22,7 +22,7 @@ class SQLDatabaseChain(Chain, BaseModel):
         db_chain = SelfAskWithSearchChain(llm=OpenAI(), database=db)
     """
 
-    llm: LLM
+    llm: BaseLLM
     """LLM wrapper to use."""
     database: SQLDatabase
     """SQL Database to connect to."""

@@ -7,19 +7,19 @@ from langchain.chains.combine_documents.refine import RefineDocumentsChain
 from langchain.chains.combine_documents.stuff import StuffDocumentsChain
 from langchain.chains.llm import LLMChain
 from langchain.chains.summarize import map_reduce_prompt, refine_prompts, stuff_prompt
-from langchain.llms.base import LLM
+from langchain.llms.base import BaseLLM
 from langchain.prompts.base import BasePromptTemplate
 
 class LoadingCallable(Protocol):
     """Interface for loading the combine documents chain."""
 
-    def __call__(self, llm: LLM, **kwargs: Any) -> BaseCombineDocumentsChain:
+    def __call__(self, llm: BaseLLM, **kwargs: Any) -> BaseCombineDocumentsChain:
         """Callable to load the combine documents chain."""
 
 def _load_stuff_chain(
-    llm: LLM,
+    llm: BaseLLM,
     prompt: BasePromptTemplate = stuff_prompt.PROMPT,
     document_variable_name: str = "text",
     **kwargs: Any,
@@ -32,7 +32,7 @@ def _load_stuff_chain(
 def _load_map_reduce_chain(
-    llm: LLM,
+    llm: BaseLLM,
     map_prompt: BasePromptTemplate = map_reduce_prompt.PROMPT,
     combine_prompt: BasePromptTemplate = map_reduce_prompt.PROMPT,
     combine_document_variable_name: str = "text",
@@ -63,7 +63,7 @@ def _load_map_reduce_chain(
 def _load_refine_chain(
-    llm: LLM,
+    llm: BaseLLM,
     question_prompt: BasePromptTemplate = refine_prompts.PROMPT,
     refine_prompt: BasePromptTemplate = refine_prompts.REFINE_PROMPT,
     document_variable_name: str = "text",
@@ -82,7 +82,7 @@ def _load_refine_chain(
 def load_summarize_chain(
-    llm: LLM, chain_type: str = "stuff", **kwargs: Any
+    llm: BaseLLM, chain_type: str = "stuff", **kwargs: Any
 ) -> BaseCombineDocumentsChain:
     """Load summarizing chain.

@@ -10,7 +10,7 @@ from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
 from langchain.chains.combine_documents.stuff import StuffDocumentsChain
 from langchain.chains.llm import LLMChain
 from langchain.chains.vector_db_qa.prompt import PROMPT
-from langchain.llms.base import LLM
+from langchain.llms.base import BaseLLM
 from langchain.prompts import PromptTemplate
 from langchain.vectorstores.base import VectorStore
@@ -84,7 +84,7 @@ class VectorDBQA(Chain, BaseModel):
     @classmethod
     def from_llm(
-        cls, llm: LLM, prompt: PromptTemplate = PROMPT, **kwargs: Any
+        cls, llm: BaseLLM, prompt: PromptTemplate = PROMPT, **kwargs: Any
     ) -> VectorDBQA:
         """Initialize from LLM."""
         llm_chain = LLMChain(llm=llm, prompt=prompt)

@@ -2,7 +2,7 @@
 from typing import List
 
 from langchain.chains.llm import LLMChain
-from langchain.llms.base import LLM
+from langchain.llms.base import BaseLLM
 from langchain.prompts.few_shot import FewShotPromptTemplate
 from langchain.prompts.prompt import PromptTemplate
@@ -10,7 +10,7 @@ TEST_GEN_TEMPLATE_SUFFIX = "Add another example."
 def generate_example(
-    examples: List[dict], llm: LLM, prompt_template: PromptTemplate
+    examples: List[dict], llm: BaseLLM, prompt_template: PromptTemplate
 ) -> str:
     """Return another example given a list of examples for a prompt."""
     prompt = FewShotPromptTemplate(

@@ -2,7 +2,7 @@
 from typing import Dict, Type
 
 from langchain.llms.ai21 import AI21
-from langchain.llms.base import LLM
+from langchain.llms.base import BaseLLM
 from langchain.llms.cohere import Cohere
 from langchain.llms.huggingface_hub import HuggingFaceHub
 from langchain.llms.huggingface_pipeline import HuggingFacePipeline
@@ -18,7 +18,7 @@ __all__ = [
     "AI21",
 ]
 
-type_to_cls_dict: Dict[str, Type[LLM]] = {
+type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
     "ai21": AI21,
     "cohere": Cohere,
     "huggingface_hub": HuggingFaceHub,

@@ -21,7 +21,7 @@ class LLMResult(NamedTuple):
     """For arbitrary LLM provider specific output."""
 
-class LLM(BaseModel, ABC):
+class BaseLLM(BaseModel, ABC):
     """LLM wrapper should take in a prompt and return a string."""
 
     class Config:
@@ -29,16 +29,11 @@ class LLM(BaseModel, ABC):
         extra = Extra.forbid
 
+    @abstractmethod
     def _generate(
         self, prompts: List[str], stop: Optional[List[str]] = None
     ) -> LLMResult:
-        """Run the LLM on the given prompt and input."""
-        # TODO: add caching here.
-        generations = []
-        for prompt in prompts:
-            text = self(prompt, stop=stop)
-            generations.append([Generation(text=text)])
-        return LLMResult(generations=generations)
+        """Run the LLM on the given prompts."""
 
     def generate(
         self, prompts: List[str], stop: Optional[List[str]] = None
@@ -88,28 +83,9 @@ class LLM(BaseModel, ABC):
         # calculate the number of tokens in the tokenized text
         return len(tokenized_text)
 
-    @abstractmethod
-    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
-        """Run the LLM on the given prompt and input."""
-
     def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
         """Check Cache and run the LLM on the given prompt and input."""
-        if langchain.llm_cache is None:
-            return self._call(prompt, stop=stop)
-        params = self._llm_dict()
-        params["stop"] = stop
-        llm_string = str(sorted([(k, v) for k, v in params.items()]))
-        if langchain.cache is not None:
-            cache_val = langchain.llm_cache.lookup(prompt, llm_string)
-            if cache_val is not None:
-                if isinstance(cache_val, str):
-                    return cache_val
-                else:
-                    return cache_val[0].text
-        return_val = self._call(prompt, stop=stop)
-        if langchain.cache is not None:
-            langchain.llm_cache.update(prompt, llm_string, return_val)
-        return return_val
+        return self.generate([prompt], stop=stop).generations[0][0].text
 
     @property
     def _identifying_params(self) -> Mapping[str, Any]:
@@ -163,3 +139,26 @@ class LLM(BaseModel, ABC):
             yaml.dump(prompt_dict, f, default_flow_style=False)
         else:
             raise ValueError(f"{save_path} must be json or yaml")
+
+
+class LLM(BaseLLM):
+    """LLM class that expect subclasses to implement a simpler call method.
+
+    The purpose of this class is to expose a simpler interface for working
+    with LLMs, rather than expect the user to implement the full _generate method.
+    """
+
+    @abstractmethod
+    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
+        """Run the LLM on the given prompt and input."""
+
+    def _generate(
+        self, prompts: List[str], stop: Optional[List[str]] = None
+    ) -> LLMResult:
+        """Run the LLM on the given prompt and input."""
+        # TODO: add caching here.
+        generations = []
+        for prompt in prompts:
+            text = self._call(prompt, stop=stop)
+            generations.append([Generation(text=text)])
+        return LLMResult(generations=generations)
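The practical effect of the split: provider classes implement `_generate` over whole batches, while simple wrappers subclass the new `LLM` and implement only `_call`. A hypothetical subclass as a sketch (`FakeStaticLLM` is not part of the library; the `_llm_type` property is included on the assumption that the base class of this era uses it for serialization):

    from typing import List, Optional

    from langchain.llms.base import LLM


    class FakeStaticLLM(LLM):
        """Toy wrapper that returns a canned string (illustration only)."""

        @property
        def _llm_type(self) -> str:
            return "fake-static"

        def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
            # The inherited _generate wraps this string into Generation and
            # LLMResult objects, so generate() and caching work unchanged.
            return "static answer"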

@@ -6,10 +6,10 @@ from typing import Union
 import yaml
 
 from langchain.llms import type_to_cls_dict
-from langchain.llms.base import LLM
+from langchain.llms.base import BaseLLM
 
-def load_llm_from_config(config: dict) -> LLM:
+def load_llm_from_config(config: dict) -> BaseLLM:
     """Load LLM from Config Dict."""
     if "_type" not in config:
         raise ValueError("Must specify an LLM Type in config")
@@ -22,7 +22,7 @@ def load_llm_from_config(config: dict) -> LLM:
     return llm_cls(**config)
 
-def load_llm(file: Union[str, Path]) -> LLM:
+def load_llm(file: Union[str, Path]) -> BaseLLM:
     """Load LLM from file."""
     # Convert file to Path object.
     if isinstance(file, str):
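A round-trip through the loader, assuming the `save` method seen earlier in llms/base.py writes a "_type" key that `load_llm_from_config` dispatches on (the file name here is illustrative):

    from langchain.llms import OpenAI
    from langchain.llms.loading import load_llm

    llm = OpenAI(temperature=0)
    llm.save("llm.json")             # serializes params plus a "_type" key
    restored = load_llm("llm.json")  # looks up the type in type_to_cls_dict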

@@ -4,12 +4,12 @@ from typing import Any, Dict, Generator, List, Mapping, Optional
 from pydantic import BaseModel, Extra, Field, root_validator
 
-from langchain.llms.base import LLM, LLMResult
+from langchain.llms.base import BaseLLM, LLMResult
 from langchain.schema import Generation
 from langchain.utils import get_from_dict_or_env
 
-class OpenAI(LLM, BaseModel):
+class OpenAI(BaseLLM, BaseModel):
     """Wrapper around OpenAI large language models.
 
     To use, you should have the ``openai`` python package installed, and the
@@ -197,23 +197,6 @@ class OpenAI(LLM, BaseModel):
         """Return type of llm."""
         return "openai"
 
-    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
-        """Call out to OpenAI's create endpoint.
-
-        Args:
-            prompt: The prompt to pass into the model.
-            stop: Optional list of stop words to use when generating.
-
-        Returns:
-            The string generated by the model.
-
-        Example:
-            .. code-block:: python
-
-                response = openai("Tell me a joke.")
-        """
-        return self.generate([prompt], stop=stop).generations[0][0].text
-
     def get_num_tokens(self, text: str) -> int:
         """Calculate num tokens with tiktoken package."""
         # tiktoken NOT supported for Python 3.8 or below
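Dropping `OpenAI._call` is safe because `OpenAI` now sits on the batch side of the split: it subclasses `BaseLLM` and supplies its own `_generate` (note the `LLMResult` and `Generation` imports above), so single-prompt calls reach it through the inherited `__call__`:

    from langchain.llms import OpenAI

    openai = OpenAI(temperature=0)
    # BaseLLM.__call__ -> generate([prompt]) -> OpenAI._generate
    response = openai("Tell me a joke.")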

@@ -6,7 +6,7 @@ from typing import List, Optional, Sequence
 from langchain.chains.base import Chain
 from langchain.chains.llm import LLMChain
 from langchain.input import get_color_mapping, print_text
-from langchain.llms.base import LLM
+from langchain.llms.base import BaseLLM
 from langchain.prompts.prompt import PromptTemplate
@@ -46,7 +46,7 @@ class ModelLaboratory:
     @classmethod
     def from_llms(
-        cls, llms: List[LLM], prompt: Optional[PromptTemplate] = None
+        cls, llms: List[BaseLLM], prompt: Optional[PromptTemplate] = None
     ) -> ModelLaboratory:
         """Initialize with LLMs to experiment with and optional prompt.

@@ -1,9 +1,9 @@
 """Utils for LLM Tests."""
-from langchain.llms.base import LLM
+from langchain.llms.base import BaseLLM
 
-def assert_llm_equality(llm: LLM, loaded_llm: LLM) -> None:
+def assert_llm_equality(llm: BaseLLM, loaded_llm: BaseLLM) -> None:
     """Assert LLM Equality for tests."""
     # Check that they are the same type.
     assert type(llm) == type(loaded_llm)
