diff --git a/langchain/agents/agent.py b/langchain/agents/agent.py index fb60115c..d08bbe18 100644 --- a/langchain/agents/agent.py +++ b/langchain/agents/agent.py @@ -11,7 +11,7 @@ from langchain.agents.tools import Tool from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.input import get_color_mapping -from langchain.llms.base import LLM +from langchain.llms.base import BaseLLM from langchain.prompts.base import BasePromptTemplate from langchain.schema import AgentAction @@ -87,7 +87,9 @@ class Agent(Chain, BaseModel, ABC): pass @classmethod - def from_llm_and_tools(cls, llm: LLM, tools: List[Tool], **kwargs: Any) -> Agent: + def from_llm_and_tools( + cls, llm: BaseLLM, tools: List[Tool], **kwargs: Any + ) -> Agent: """Construct an agent from an LLM and tools.""" cls._validate_tools(tools) llm_chain = LLMChain(llm=llm, prompt=cls.create_prompt(tools)) diff --git a/langchain/agents/loading.py b/langchain/agents/loading.py index 75d24d1f..98b644ed 100644 --- a/langchain/agents/loading.py +++ b/langchain/agents/loading.py @@ -6,7 +6,7 @@ from langchain.agents.mrkl.base import ZeroShotAgent from langchain.agents.react.base import ReActDocstoreAgent from langchain.agents.self_ask_with_search.base import SelfAskWithSearchAgent from langchain.agents.tools import Tool -from langchain.llms.base import LLM +from langchain.llms.base import BaseLLM AGENT_TO_CLASS = { "zero-shot-react-description": ZeroShotAgent, @@ -17,7 +17,7 @@ AGENT_TO_CLASS = { def initialize_agent( tools: List[Tool], - llm: LLM, + llm: BaseLLM, agent: str = "zero-shot-react-description", **kwargs: Any, ) -> Agent: diff --git a/langchain/agents/mrkl/base.py b/langchain/agents/mrkl/base.py index 1519c38d..98c1845f 100644 --- a/langchain/agents/mrkl/base.py +++ b/langchain/agents/mrkl/base.py @@ -6,7 +6,7 @@ from typing import Any, Callable, List, NamedTuple, Optional, Tuple from langchain.agents.agent import Agent from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX from langchain.agents.tools import Tool -from langchain.llms.base import LLM +from langchain.llms.base import BaseLLM from langchain.prompts import PromptTemplate FINAL_ANSWER_ACTION = "Final Answer: " @@ -116,7 +116,9 @@ class MRKLChain(ZeroShotAgent): """ @classmethod - def from_chains(cls, llm: LLM, chains: List[ChainConfig], **kwargs: Any) -> Agent: + def from_chains( + cls, llm: BaseLLM, chains: List[ChainConfig], **kwargs: Any + ) -> Agent: """User friendly way to initialize the MRKL chain. This is intended to be an easy way to get up and running with the diff --git a/langchain/agents/react/base.py b/langchain/agents/react/base.py index ca380e1a..41cb9f02 100644 --- a/langchain/agents/react/base.py +++ b/langchain/agents/react/base.py @@ -11,7 +11,7 @@ from langchain.agents.tools import Tool from langchain.chains.llm import LLMChain from langchain.docstore.base import Docstore from langchain.docstore.document import Document -from langchain.llms.base import LLM +from langchain.llms.base import BaseLLM from langchain.prompts.base import BasePromptTemplate @@ -123,7 +123,7 @@ class ReActChain(ReActDocstoreAgent): react = ReAct(llm=OpenAI()) """ - def __init__(self, llm: LLM, docstore: Docstore, **kwargs: Any): + def __init__(self, llm: BaseLLM, docstore: Docstore, **kwargs: Any): """Initialize with the LLM and a docstore.""" docstore_explorer = DocstoreExplorer(docstore) tools = [ diff --git a/langchain/agents/self_ask_with_search/base.py b/langchain/agents/self_ask_with_search/base.py index d8184fb4..0ac8573d 100644 --- a/langchain/agents/self_ask_with_search/base.py +++ b/langchain/agents/self_ask_with_search/base.py @@ -5,7 +5,7 @@ from langchain.agents.agent import Agent from langchain.agents.self_ask_with_search.prompt import PROMPT from langchain.agents.tools import Tool from langchain.chains.llm import LLMChain -from langchain.llms.base import LLM +from langchain.llms.base import BaseLLM from langchain.prompts.base import BasePromptTemplate from langchain.serpapi import SerpAPIWrapper @@ -72,7 +72,7 @@ class SelfAskWithSearchChain(SelfAskWithSearchAgent): self_ask = SelfAskWithSearchChain(llm=OpenAI(), search_chain=search_chain) """ - def __init__(self, llm: LLM, search_chain: SerpAPIWrapper, **kwargs: Any): + def __init__(self, llm: BaseLLM, search_chain: SerpAPIWrapper, **kwargs: Any): """Initialize with just an LLM and a search chain.""" search_tool = Tool(name="Intermediate Answer", func=search_chain.run) llm_chain = LLMChain(llm=llm, prompt=PROMPT) diff --git a/langchain/cache.py b/langchain/cache.py index 258c0383..afa0eb67 100644 --- a/langchain/cache.py +++ b/langchain/cache.py @@ -1,6 +1,6 @@ """Beta Feature: base interface for cache.""" from abc import ABC, abstractmethod -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple from sqlalchemy import Column, Integer, String, create_engine, select from sqlalchemy.engine.base import Engine @@ -9,7 +9,7 @@ from sqlalchemy.orm import Session from langchain.schema import Generation -RETURN_VAL_TYPE = Union[List[Generation], str] +RETURN_VAL_TYPE = List[Generation] class BaseCache(ABC): @@ -43,15 +43,6 @@ class InMemoryCache(BaseCache): Base = declarative_base() -class LLMCache(Base): # type: ignore - """SQLite table for simple LLM cache (string only).""" - - __tablename__ = "llm_cache" - prompt = Column(String, primary_key=True) - llm = Column(String, primary_key=True) - response = Column(String) - - class FullLLMCache(Base): # type: ignore """SQLite table for full LLM Cache (all generations).""" @@ -84,29 +75,16 @@ class SQLAlchemyCache(BaseCache): generations.append(Generation(text=row[0])) if len(generations) > 0: return generations - stmt = ( - select(LLMCache.response) - .where(LLMCache.prompt == prompt) - .where(LLMCache.llm == llm_string) - ) - with Session(self.engine) as session: - for row in session.execute(stmt): - return row[0] return None def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None: """Look up based on prompt and llm_string.""" - if isinstance(return_val, str): - item = LLMCache(prompt=prompt, llm=llm_string, response=return_val) + for i, generation in enumerate(return_val): + item = FullLLMCache( + prompt=prompt, llm=llm_string, response=generation.text, idx=i + ) with Session(self.engine) as session, session.begin(): session.add(item) - else: - for i, generation in enumerate(return_val): - item = FullLLMCache( - prompt=prompt, llm=llm_string, response=generation.text, idx=i - ) - with Session(self.engine) as session, session.begin(): - session.add(item) class SQLiteCache(SQLAlchemyCache): diff --git a/langchain/chains/api/base.py b/langchain/chains/api/base.py index 0eac3817..07591bd4 100644 --- a/langchain/chains/api/base.py +++ b/langchain/chains/api/base.py @@ -9,7 +9,7 @@ from langchain.chains.api.prompt import API_RESPONSE_PROMPT, API_URL_PROMPT from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.input import print_text -from langchain.llms.base import LLM +from langchain.llms.base import BaseLLM from langchain.requests import RequestsWrapper @@ -81,7 +81,7 @@ class APIChain(Chain, BaseModel): @classmethod def from_llm_and_api_docs( - cls, llm: LLM, api_docs: str, headers: Optional[dict] = None, **kwargs: Any + cls, llm: BaseLLM, api_docs: str, headers: Optional[dict] = None, **kwargs: Any ) -> APIChain: """Load chain from just an LLM and the api docs.""" get_request_chain = LLMChain(llm=llm, prompt=API_URL_PROMPT) diff --git a/langchain/chains/conversation/memory.py b/langchain/chains/conversation/memory.py index ebdbb7ef..0a686dde 100644 --- a/langchain/chains/conversation/memory.py +++ b/langchain/chains/conversation/memory.py @@ -6,7 +6,7 @@ from pydantic import BaseModel, Field, root_validator from langchain.chains.base import Memory from langchain.chains.conversation.prompt import SUMMARY_PROMPT from langchain.chains.llm import LLMChain -from langchain.llms.base import LLM +from langchain.llms.base import BaseLLM from langchain.prompts.base import BasePromptTemplate @@ -88,7 +88,7 @@ class ConversationSummaryMemory(Memory, BaseModel): """Conversation summarizer to memory.""" buffer: str = "" - llm: LLM + llm: BaseLLM prompt: BasePromptTemplate = SUMMARY_PROMPT memory_key: str = "history" #: :meta private: diff --git a/langchain/chains/llm.py b/langchain/chains/llm.py index 31800774..138b62a5 100644 --- a/langchain/chains/llm.py +++ b/langchain/chains/llm.py @@ -5,7 +5,7 @@ from pydantic import BaseModel, Extra import langchain from langchain.chains.base import Chain -from langchain.llms.base import LLM +from langchain.llms.base import BaseLLM from langchain.prompts.base import BasePromptTemplate @@ -25,7 +25,7 @@ class LLMChain(Chain, BaseModel): prompt: BasePromptTemplate """Prompt object to use.""" - llm: LLM + llm: BaseLLM """LLM wrapper to use.""" output_key: str = "text" #: :meta private: diff --git a/langchain/chains/llm_bash/base.py b/langchain/chains/llm_bash/base.py index 85ea58ce..76963713 100644 --- a/langchain/chains/llm_bash/base.py +++ b/langchain/chains/llm_bash/base.py @@ -7,7 +7,7 @@ from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.chains.llm_bash.prompt import PROMPT from langchain.input import print_text -from langchain.llms.base import LLM +from langchain.llms.base import BaseLLM from langchain.utilities.bash import BashProcess @@ -21,7 +21,7 @@ class LLMBashChain(Chain, BaseModel): llm_bash = LLMBashChain(llm=OpenAI()) """ - llm: LLM + llm: BaseLLM """LLM wrapper to use.""" input_key: str = "question" #: :meta private: output_key: str = "answer" #: :meta private: diff --git a/langchain/chains/llm_checker/base.py b/langchain/chains/llm_checker/base.py index 4dfc1ba3..c0192d79 100644 --- a/langchain/chains/llm_checker/base.py +++ b/langchain/chains/llm_checker/base.py @@ -14,7 +14,7 @@ from langchain.chains.llm_checker.prompt import ( REVISED_ANSWER_PROMPT, ) from langchain.chains.sequential import SequentialChain -from langchain.llms.base import LLM +from langchain.llms.base import BaseLLM from langchain.prompts import PromptTemplate @@ -28,7 +28,7 @@ class LLMCheckerChain(Chain, BaseModel): checker_chain = LLMCheckerChain(llm=llm) """ - llm: LLM + llm: BaseLLM """LLM wrapper to use.""" create_draft_answer_prompt: PromptTemplate = CREATE_DRAFT_ANSWER_PROMPT list_assertions_prompt: PromptTemplate = LIST_ASSERTIONS_PROMPT diff --git a/langchain/chains/llm_math/base.py b/langchain/chains/llm_math/base.py index 5ebd9051..a0485e9e 100644 --- a/langchain/chains/llm_math/base.py +++ b/langchain/chains/llm_math/base.py @@ -7,7 +7,7 @@ from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.chains.llm_math.prompt import PROMPT from langchain.input import print_text -from langchain.llms.base import LLM +from langchain.llms.base import BaseLLM from langchain.python import PythonREPL @@ -21,7 +21,7 @@ class LLMMathChain(Chain, BaseModel): llm_math = LLMMathChain(llm=OpenAI()) """ - llm: LLM + llm: BaseLLM """LLM wrapper to use.""" input_key: str = "question" #: :meta private: output_key: str = "answer" #: :meta private: diff --git a/langchain/chains/mapreduce.py b/langchain/chains/mapreduce.py index bb518587..ea01ab54 100644 --- a/langchain/chains/mapreduce.py +++ b/langchain/chains/mapreduce.py @@ -15,7 +15,7 @@ from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChai from langchain.chains.combine_documents.stuff import StuffDocumentsChain from langchain.chains.llm import LLMChain from langchain.docstore.document import Document -from langchain.llms.base import LLM +from langchain.llms.base import BaseLLM from langchain.prompts.base import BasePromptTemplate from langchain.text_splitter import TextSplitter @@ -32,7 +32,7 @@ class MapReduceChain(Chain, BaseModel): @classmethod def from_params( - cls, llm: LLM, prompt: BasePromptTemplate, text_splitter: TextSplitter + cls, llm: BaseLLM, prompt: BasePromptTemplate, text_splitter: TextSplitter ) -> MapReduceChain: """Construct a map-reduce chain that uses the chain for map and reduce.""" llm_chain = LLMChain(llm=llm, prompt=prompt) diff --git a/langchain/chains/natbot/base.py b/langchain/chains/natbot/base.py index 744c4bc8..6e75946f 100644 --- a/langchain/chains/natbot/base.py +++ b/langchain/chains/natbot/base.py @@ -8,7 +8,7 @@ from pydantic import BaseModel, Extra from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.chains.natbot.prompt import PROMPT -from langchain.llms.base import LLM +from langchain.llms.base import BaseLLM from langchain.llms.openai import OpenAI @@ -22,7 +22,7 @@ class NatBotChain(Chain, BaseModel): natbot = NatBotChain(llm=OpenAI(), objective="Buy me a new hat.") """ - llm: LLM + llm: BaseLLM """LLM wrapper to use.""" objective: str """Objective that NatBot is tasked with completing.""" diff --git a/langchain/chains/pal/base.py b/langchain/chains/pal/base.py index 1077c958..08335e2f 100644 --- a/langchain/chains/pal/base.py +++ b/langchain/chains/pal/base.py @@ -13,7 +13,7 @@ from langchain.chains.llm import LLMChain from langchain.chains.pal.colored_object_prompt import COLORED_OBJECT_PROMPT from langchain.chains.pal.math_prompt import MATH_PROMPT from langchain.input import print_text -from langchain.llms.base import LLM +from langchain.llms.base import BaseLLM from langchain.prompts.base import BasePromptTemplate from langchain.python import PythonREPL @@ -21,7 +21,7 @@ from langchain.python import PythonREPL class PALChain(Chain, BaseModel): """Implements Program-Aided Language Models.""" - llm: LLM + llm: BaseLLM prompt: BasePromptTemplate stop: str = "\n\n" get_answer_expr: str = "print(solution())" @@ -59,7 +59,7 @@ class PALChain(Chain, BaseModel): return {self.output_key: res.strip()} @classmethod - def from_math_prompt(cls, llm: LLM, **kwargs: Any) -> PALChain: + def from_math_prompt(cls, llm: BaseLLM, **kwargs: Any) -> PALChain: """Load PAL from math prompt.""" return cls( llm=llm, @@ -70,7 +70,7 @@ class PALChain(Chain, BaseModel): ) @classmethod - def from_colored_object_prompt(cls, llm: LLM, **kwargs: Any) -> PALChain: + def from_colored_object_prompt(cls, llm: BaseLLM, **kwargs: Any) -> PALChain: """Load PAL from colored object prompt.""" return cls( llm=llm, diff --git a/langchain/chains/qa_with_sources/__init__.py b/langchain/chains/qa_with_sources/__init__.py index 82d93f50..b56997d3 100644 --- a/langchain/chains/qa_with_sources/__init__.py +++ b/langchain/chains/qa_with_sources/__init__.py @@ -11,19 +11,19 @@ from langchain.chains.qa_with_sources import ( refine_prompts, stuff_prompt, ) -from langchain.llms.base import LLM +from langchain.llms.base import BaseLLM from langchain.prompts.base import BasePromptTemplate class LoadingCallable(Protocol): """Interface for loading the combine documents chain.""" - def __call__(self, llm: LLM, **kwargs: Any) -> BaseCombineDocumentsChain: + def __call__(self, llm: BaseLLM, **kwargs: Any) -> BaseCombineDocumentsChain: """Callable to load the combine documents chain.""" def _load_stuff_chain( - llm: LLM, + llm: BaseLLM, prompt: BasePromptTemplate = stuff_prompt.PROMPT, document_variable_name: str = "summaries", **kwargs: Any, @@ -38,7 +38,7 @@ def _load_stuff_chain( def _load_map_reduce_chain( - llm: LLM, + llm: BaseLLM, question_prompt: BasePromptTemplate = map_reduce_prompt.QUESTION_PROMPT, combine_prompt: BasePromptTemplate = map_reduce_prompt.COMBINE_PROMPT, document_prompt: BasePromptTemplate = map_reduce_prompt.EXAMPLE_PROMPT, @@ -72,7 +72,7 @@ def _load_map_reduce_chain( def _load_refine_chain( - llm: LLM, + llm: BaseLLM, question_prompt: BasePromptTemplate = refine_prompts.DEFAULT_TEXT_QA_PROMPT, refine_prompt: BasePromptTemplate = refine_prompts.DEFAULT_REFINE_PROMPT, document_prompt: BasePromptTemplate = refine_prompts.EXAMPLE_PROMPT, @@ -93,7 +93,7 @@ def _load_refine_chain( def load_qa_with_sources_chain( - llm: LLM, chain_type: str = "stuff", **kwargs: Any + llm: BaseLLM, chain_type: str = "stuff", **kwargs: Any ) -> BaseCombineDocumentsChain: """Load question answering with sources chain. diff --git a/langchain/chains/qa_with_sources/base.py b/langchain/chains/qa_with_sources/base.py index 249d0c16..48a3d017 100644 --- a/langchain/chains/qa_with_sources/base.py +++ b/langchain/chains/qa_with_sources/base.py @@ -18,7 +18,7 @@ from langchain.chains.qa_with_sources.map_reduce_prompt import ( QUESTION_PROMPT, ) from langchain.docstore.document import Document -from langchain.llms.base import LLM +from langchain.llms.base import BaseLLM from langchain.prompts.base import BasePromptTemplate @@ -35,7 +35,7 @@ class BaseQAWithSourcesChain(Chain, BaseModel, ABC): @classmethod def from_llm( cls, - llm: LLM, + llm: BaseLLM, document_prompt: BasePromptTemplate = EXAMPLE_PROMPT, question_prompt: BasePromptTemplate = QUESTION_PROMPT, combine_prompt: BasePromptTemplate = COMBINE_PROMPT, diff --git a/langchain/chains/question_answering/__init__.py b/langchain/chains/question_answering/__init__.py index 9883e068..c9fd93b1 100644 --- a/langchain/chains/question_answering/__init__.py +++ b/langchain/chains/question_answering/__init__.py @@ -11,19 +11,19 @@ from langchain.chains.question_answering import ( refine_prompts, stuff_prompt, ) -from langchain.llms.base import LLM +from langchain.llms.base import BaseLLM from langchain.prompts.base import BasePromptTemplate class LoadingCallable(Protocol): """Interface for loading the combine documents chain.""" - def __call__(self, llm: LLM, **kwargs: Any) -> BaseCombineDocumentsChain: + def __call__(self, llm: BaseLLM, **kwargs: Any) -> BaseCombineDocumentsChain: """Callable to load the combine documents chain.""" def _load_stuff_chain( - llm: LLM, + llm: BaseLLM, prompt: BasePromptTemplate = stuff_prompt.PROMPT, document_variable_name: str = "context", **kwargs: Any, @@ -36,7 +36,7 @@ def _load_stuff_chain( def _load_map_reduce_chain( - llm: LLM, + llm: BaseLLM, question_prompt: BasePromptTemplate = map_reduce_prompt.QUESTION_PROMPT, combine_prompt: BasePromptTemplate = map_reduce_prompt.COMBINE_PROMPT, combine_document_variable_name: str = "summaries", @@ -67,7 +67,7 @@ def _load_map_reduce_chain( def _load_refine_chain( - llm: LLM, + llm: BaseLLM, question_prompt: BasePromptTemplate = refine_prompts.DEFAULT_TEXT_QA_PROMPT, refine_prompt: BasePromptTemplate = refine_prompts.DEFAULT_REFINE_PROMPT, document_variable_name: str = "context_str", @@ -86,7 +86,7 @@ def _load_refine_chain( def load_qa_chain( - llm: LLM, chain_type: str = "stuff", **kwargs: Any + llm: BaseLLM, chain_type: str = "stuff", **kwargs: Any ) -> BaseCombineDocumentsChain: """Load question answering chain. diff --git a/langchain/chains/sql_database/base.py b/langchain/chains/sql_database/base.py index 7ceab5fb..103993f3 100644 --- a/langchain/chains/sql_database/base.py +++ b/langchain/chains/sql_database/base.py @@ -7,7 +7,7 @@ from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.chains.sql_database.prompt import PROMPT from langchain.input import print_text -from langchain.llms.base import LLM +from langchain.llms.base import BaseLLM from langchain.sql_database import SQLDatabase @@ -22,7 +22,7 @@ class SQLDatabaseChain(Chain, BaseModel): db_chain = SelfAskWithSearchChain(llm=OpenAI(), database=db) """ - llm: LLM + llm: BaseLLM """LLM wrapper to use.""" database: SQLDatabase """SQL Database to connect to.""" diff --git a/langchain/chains/summarize/__init__.py b/langchain/chains/summarize/__init__.py index e613ff8d..2fca7f2e 100644 --- a/langchain/chains/summarize/__init__.py +++ b/langchain/chains/summarize/__init__.py @@ -7,19 +7,19 @@ from langchain.chains.combine_documents.refine import RefineDocumentsChain from langchain.chains.combine_documents.stuff import StuffDocumentsChain from langchain.chains.llm import LLMChain from langchain.chains.summarize import map_reduce_prompt, refine_prompts, stuff_prompt -from langchain.llms.base import LLM +from langchain.llms.base import BaseLLM from langchain.prompts.base import BasePromptTemplate class LoadingCallable(Protocol): """Interface for loading the combine documents chain.""" - def __call__(self, llm: LLM, **kwargs: Any) -> BaseCombineDocumentsChain: + def __call__(self, llm: BaseLLM, **kwargs: Any) -> BaseCombineDocumentsChain: """Callable to load the combine documents chain.""" def _load_stuff_chain( - llm: LLM, + llm: BaseLLM, prompt: BasePromptTemplate = stuff_prompt.PROMPT, document_variable_name: str = "text", **kwargs: Any, @@ -32,7 +32,7 @@ def _load_stuff_chain( def _load_map_reduce_chain( - llm: LLM, + llm: BaseLLM, map_prompt: BasePromptTemplate = map_reduce_prompt.PROMPT, combine_prompt: BasePromptTemplate = map_reduce_prompt.PROMPT, combine_document_variable_name: str = "text", @@ -63,7 +63,7 @@ def _load_map_reduce_chain( def _load_refine_chain( - llm: LLM, + llm: BaseLLM, question_prompt: BasePromptTemplate = refine_prompts.PROMPT, refine_prompt: BasePromptTemplate = refine_prompts.REFINE_PROMPT, document_variable_name: str = "text", @@ -82,7 +82,7 @@ def _load_refine_chain( def load_summarize_chain( - llm: LLM, chain_type: str = "stuff", **kwargs: Any + llm: BaseLLM, chain_type: str = "stuff", **kwargs: Any ) -> BaseCombineDocumentsChain: """Load summarizing chain. diff --git a/langchain/chains/vector_db_qa/base.py b/langchain/chains/vector_db_qa/base.py index 30cc1215..d6f7ab6f 100644 --- a/langchain/chains/vector_db_qa/base.py +++ b/langchain/chains/vector_db_qa/base.py @@ -10,7 +10,7 @@ from langchain.chains.combine_documents.base import BaseCombineDocumentsChain from langchain.chains.combine_documents.stuff import StuffDocumentsChain from langchain.chains.llm import LLMChain from langchain.chains.vector_db_qa.prompt import PROMPT -from langchain.llms.base import LLM +from langchain.llms.base import BaseLLM from langchain.prompts import PromptTemplate from langchain.vectorstores.base import VectorStore @@ -84,7 +84,7 @@ class VectorDBQA(Chain, BaseModel): @classmethod def from_llm( - cls, llm: LLM, prompt: PromptTemplate = PROMPT, **kwargs: Any + cls, llm: BaseLLM, prompt: PromptTemplate = PROMPT, **kwargs: Any ) -> VectorDBQA: """Initialize from LLM.""" llm_chain = LLMChain(llm=llm, prompt=prompt) diff --git a/langchain/example_generator.py b/langchain/example_generator.py index 58816e56..7c309d05 100644 --- a/langchain/example_generator.py +++ b/langchain/example_generator.py @@ -2,7 +2,7 @@ from typing import List from langchain.chains.llm import LLMChain -from langchain.llms.base import LLM +from langchain.llms.base import BaseLLM from langchain.prompts.few_shot import FewShotPromptTemplate from langchain.prompts.prompt import PromptTemplate @@ -10,7 +10,7 @@ TEST_GEN_TEMPLATE_SUFFIX = "Add another example." def generate_example( - examples: List[dict], llm: LLM, prompt_template: PromptTemplate + examples: List[dict], llm: BaseLLM, prompt_template: PromptTemplate ) -> str: """Return another example given a list of examples for a prompt.""" prompt = FewShotPromptTemplate( diff --git a/langchain/llms/__init__.py b/langchain/llms/__init__.py index f32ef72d..72696af4 100644 --- a/langchain/llms/__init__.py +++ b/langchain/llms/__init__.py @@ -2,7 +2,7 @@ from typing import Dict, Type from langchain.llms.ai21 import AI21 -from langchain.llms.base import LLM +from langchain.llms.base import BaseLLM from langchain.llms.cohere import Cohere from langchain.llms.huggingface_hub import HuggingFaceHub from langchain.llms.huggingface_pipeline import HuggingFacePipeline @@ -18,7 +18,7 @@ __all__ = [ "AI21", ] -type_to_cls_dict: Dict[str, Type[LLM]] = { +type_to_cls_dict: Dict[str, Type[BaseLLM]] = { "ai21": AI21, "cohere": Cohere, "huggingface_hub": HuggingFaceHub, diff --git a/langchain/llms/base.py b/langchain/llms/base.py index 4488475d..09ebaf06 100644 --- a/langchain/llms/base.py +++ b/langchain/llms/base.py @@ -21,7 +21,7 @@ class LLMResult(NamedTuple): """For arbitrary LLM provider specific output.""" -class LLM(BaseModel, ABC): +class BaseLLM(BaseModel, ABC): """LLM wrapper should take in a prompt and return a string.""" class Config: @@ -29,16 +29,11 @@ class LLM(BaseModel, ABC): extra = Extra.forbid + @abstractmethod def _generate( self, prompts: List[str], stop: Optional[List[str]] = None ) -> LLMResult: - """Run the LLM on the given prompt and input.""" - # TODO: add caching here. - generations = [] - for prompt in prompts: - text = self(prompt, stop=stop) - generations.append([Generation(text=text)]) - return LLMResult(generations=generations) + """Run the LLM on the given prompts.""" def generate( self, prompts: List[str], stop: Optional[List[str]] = None @@ -88,28 +83,9 @@ class LLM(BaseModel, ABC): # calculate the number of tokens in the tokenized text return len(tokenized_text) - @abstractmethod - def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: - """Run the LLM on the given prompt and input.""" - def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str: """Check Cache and run the LLM on the given prompt and input.""" - if langchain.llm_cache is None: - return self._call(prompt, stop=stop) - params = self._llm_dict() - params["stop"] = stop - llm_string = str(sorted([(k, v) for k, v in params.items()])) - if langchain.cache is not None: - cache_val = langchain.llm_cache.lookup(prompt, llm_string) - if cache_val is not None: - if isinstance(cache_val, str): - return cache_val - else: - return cache_val[0].text - return_val = self._call(prompt, stop=stop) - if langchain.cache is not None: - langchain.llm_cache.update(prompt, llm_string, return_val) - return return_val + return self.generate([prompt], stop=stop).generations[0][0].text @property def _identifying_params(self) -> Mapping[str, Any]: @@ -163,3 +139,26 @@ class LLM(BaseModel, ABC): yaml.dump(prompt_dict, f, default_flow_style=False) else: raise ValueError(f"{save_path} must be json or yaml") + + +class LLM(BaseLLM): + """LLM class that expect subclasses to implement a simpler call method. + + The purpose of this class is to expose a simpler interface for working + with LLMs, rather than expect the user to implement the full _generate method. + """ + + @abstractmethod + def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + """Run the LLM on the given prompt and input.""" + + def _generate( + self, prompts: List[str], stop: Optional[List[str]] = None + ) -> LLMResult: + """Run the LLM on the given prompt and input.""" + # TODO: add caching here. + generations = [] + for prompt in prompts: + text = self._call(prompt, stop=stop) + generations.append([Generation(text=text)]) + return LLMResult(generations=generations) diff --git a/langchain/llms/loading.py b/langchain/llms/loading.py index d5881ba4..723606be 100644 --- a/langchain/llms/loading.py +++ b/langchain/llms/loading.py @@ -6,10 +6,10 @@ from typing import Union import yaml from langchain.llms import type_to_cls_dict -from langchain.llms.base import LLM +from langchain.llms.base import BaseLLM -def load_llm_from_config(config: dict) -> LLM: +def load_llm_from_config(config: dict) -> BaseLLM: """Load LLM from Config Dict.""" if "_type" not in config: raise ValueError("Must specify an LLM Type in config") @@ -22,7 +22,7 @@ def load_llm_from_config(config: dict) -> LLM: return llm_cls(**config) -def load_llm(file: Union[str, Path]) -> LLM: +def load_llm(file: Union[str, Path]) -> BaseLLM: """Load LLM from file.""" # Convert file to Path object. if isinstance(file, str): diff --git a/langchain/llms/openai.py b/langchain/llms/openai.py index 1e285142..d50795a5 100644 --- a/langchain/llms/openai.py +++ b/langchain/llms/openai.py @@ -4,12 +4,12 @@ from typing import Any, Dict, Generator, List, Mapping, Optional from pydantic import BaseModel, Extra, Field, root_validator -from langchain.llms.base import LLM, LLMResult +from langchain.llms.base import BaseLLM, LLMResult from langchain.schema import Generation from langchain.utils import get_from_dict_or_env -class OpenAI(LLM, BaseModel): +class OpenAI(BaseLLM, BaseModel): """Wrapper around OpenAI large language models. To use, you should have the ``openai`` python package installed, and the @@ -197,23 +197,6 @@ class OpenAI(LLM, BaseModel): """Return type of llm.""" return "openai" - def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: - """Call out to OpenAI's create endpoint. - - Args: - prompt: The prompt to pass into the model. - stop: Optional list of stop words to use when generating. - - Returns: - The string generated by the model. - - Example: - .. code-block:: python - - response = openai("Tell me a joke.") - """ - return self.generate([prompt], stop=stop).generations[0][0].text - def get_num_tokens(self, text: str) -> int: """Calculate num tokens with tiktoken package.""" # tiktoken NOT supported for Python 3.8 or below diff --git a/langchain/model_laboratory.py b/langchain/model_laboratory.py index 090c8e52..0ba871b9 100644 --- a/langchain/model_laboratory.py +++ b/langchain/model_laboratory.py @@ -6,7 +6,7 @@ from typing import List, Optional, Sequence from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.input import get_color_mapping, print_text -from langchain.llms.base import LLM +from langchain.llms.base import BaseLLM from langchain.prompts.prompt import PromptTemplate @@ -46,7 +46,7 @@ class ModelLaboratory: @classmethod def from_llms( - cls, llms: List[LLM], prompt: Optional[PromptTemplate] = None + cls, llms: List[BaseLLM], prompt: Optional[PromptTemplate] = None ) -> ModelLaboratory: """Initialize with LLMs to experiment with and optional prompt. diff --git a/tests/integration_tests/llms/utils.py b/tests/integration_tests/llms/utils.py index c05445d4..31a27d88 100644 --- a/tests/integration_tests/llms/utils.py +++ b/tests/integration_tests/llms/utils.py @@ -1,9 +1,9 @@ """Utils for LLM Tests.""" -from langchain.llms.base import LLM +from langchain.llms.base import BaseLLM -def assert_llm_equality(llm: LLM, loaded_llm: LLM) -> None: +def assert_llm_equality(llm: BaseLLM, loaded_llm: BaseLLM) -> None: """Assert LLM Equality for tests.""" # Check that they are the same type. assert type(llm) == type(loaded_llm)