diff --git a/langchain/chains/api/base.py b/langchain/chains/api/base.py index 5036a72c..5cbded4e 100644 --- a/langchain/chains/api/base.py +++ b/langchain/chains/api/base.py @@ -8,9 +8,9 @@ from pydantic import BaseModel, Field, root_validator from langchain.chains.api.prompt import API_RESPONSE_PROMPT, API_URL_PROMPT from langchain.chains.base import Chain from langchain.chains.llm import LLMChain -from langchain.llms.base import BaseLLM from langchain.prompts import BasePromptTemplate from langchain.requests import RequestsWrapper +from langchain.schema import BaseLanguageModel class APIChain(Chain, BaseModel): @@ -84,7 +84,7 @@ class APIChain(Chain, BaseModel): @classmethod def from_llm_and_api_docs( cls, - llm: BaseLLM, + llm: BaseLanguageModel, api_docs: str, headers: Optional[dict] = None, api_url_prompt: BasePromptTemplate = API_URL_PROMPT, diff --git a/langchain/chains/constitutional_ai/base.py b/langchain/chains/constitutional_ai/base.py index b68e22de..b78aa3a0 100644 --- a/langchain/chains/constitutional_ai/base.py +++ b/langchain/chains/constitutional_ai/base.py @@ -5,8 +5,8 @@ from langchain.chains.base import Chain from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple from langchain.chains.constitutional_ai.prompts import CRITIQUE_PROMPT, REVISION_PROMPT from langchain.chains.llm import LLMChain -from langchain.llms.base import BaseLLM from langchain.prompts.base import BasePromptTemplate +from langchain.schema import BaseLanguageModel class ConstitutionalChain(Chain): @@ -45,7 +45,7 @@ class ConstitutionalChain(Chain): @classmethod def from_llm( cls, - llm: BaseLLM, + llm: BaseLanguageModel, chain: LLMChain, critique_prompt: BasePromptTemplate = CRITIQUE_PROMPT, revision_prompt: BasePromptTemplate = REVISION_PROMPT, diff --git a/langchain/chains/llm_bash/base.py b/langchain/chains/llm_bash/base.py index 5a0d88eb..994df302 100644 --- a/langchain/chains/llm_bash/base.py +++ b/langchain/chains/llm_bash/base.py @@ -6,8 +6,8 @@ from pydantic import BaseModel, Extra from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.chains.llm_bash.prompt import PROMPT -from langchain.llms.base import BaseLLM from langchain.prompts.base import BasePromptTemplate +from langchain.schema import BaseLanguageModel from langchain.utilities.bash import BashProcess @@ -21,7 +21,7 @@ class LLMBashChain(Chain, BaseModel): llm_bash = LLMBashChain(llm=OpenAI()) """ - llm: BaseLLM + llm: BaseLanguageModel """LLM wrapper to use.""" input_key: str = "question" #: :meta private: output_key: str = "answer" #: :meta private: diff --git a/langchain/chains/pal/base.py b/langchain/chains/pal/base.py index 5f574aaa..443dd137 100644 --- a/langchain/chains/pal/base.py +++ b/langchain/chains/pal/base.py @@ -12,15 +12,15 @@ from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.chains.pal.colored_object_prompt import COLORED_OBJECT_PROMPT from langchain.chains.pal.math_prompt import MATH_PROMPT -from langchain.llms.base import BaseLLM from langchain.prompts.base import BasePromptTemplate from langchain.python import PythonREPL +from langchain.schema import BaseLanguageModel class PALChain(Chain, BaseModel): """Implements Program-Aided Language Models.""" - llm: BaseLLM + llm: BaseLanguageModel prompt: BasePromptTemplate stop: str = "\n\n" get_answer_expr: str = "print(solution())" @@ -68,7 +68,7 @@ class PALChain(Chain, BaseModel): return output @classmethod - def from_math_prompt(cls, llm: BaseLLM, **kwargs: Any) -> PALChain: + def from_math_prompt(cls, llm: BaseLanguageModel, **kwargs: Any) -> PALChain: """Load PAL from math prompt.""" return cls( llm=llm, @@ -79,7 +79,9 @@ class PALChain(Chain, BaseModel): ) @classmethod - def from_colored_object_prompt(cls, llm: BaseLLM, **kwargs: Any) -> PALChain: + def from_colored_object_prompt( + cls, llm: BaseLanguageModel, **kwargs: Any + ) -> PALChain: """Load PAL from colored object prompt.""" return cls( llm=llm, diff --git a/langchain/chains/qa_with_sources/base.py b/langchain/chains/qa_with_sources/base.py index 3cc15147..628a7b36 100644 --- a/langchain/chains/qa_with_sources/base.py +++ b/langchain/chains/qa_with_sources/base.py @@ -19,8 +19,8 @@ from langchain.chains.qa_with_sources.map_reduce_prompt import ( QUESTION_PROMPT, ) from langchain.docstore.document import Document -from langchain.llms.base import BaseLLM from langchain.prompts.base import BasePromptTemplate +from langchain.schema import BaseLanguageModel class BaseQAWithSourcesChain(Chain, BaseModel, ABC): @@ -38,7 +38,7 @@ class BaseQAWithSourcesChain(Chain, BaseModel, ABC): @classmethod def from_llm( cls, - llm: BaseLLM, + llm: BaseLanguageModel, document_prompt: BasePromptTemplate = EXAMPLE_PROMPT, question_prompt: BasePromptTemplate = QUESTION_PROMPT, combine_prompt: BasePromptTemplate = COMBINE_PROMPT, @@ -65,7 +65,7 @@ class BaseQAWithSourcesChain(Chain, BaseModel, ABC): @classmethod def from_chain_type( cls, - llm: BaseLLM, + llm: BaseLanguageModel, chain_type: str = "stuff", chain_type_kwargs: Optional[dict] = None, **kwargs: Any, diff --git a/langchain/chains/qa_with_sources/loading.py b/langchain/chains/qa_with_sources/loading.py index 4af17329..c1d923ae 100644 --- a/langchain/chains/qa_with_sources/loading.py +++ b/langchain/chains/qa_with_sources/loading.py @@ -13,19 +13,21 @@ from langchain.chains.qa_with_sources import ( stuff_prompt, ) from langchain.chains.question_answering import map_rerank_prompt -from langchain.llms.base import BaseLLM from langchain.prompts.base import BasePromptTemplate +from langchain.schema import BaseLanguageModel class LoadingCallable(Protocol): """Interface for loading the combine documents chain.""" - def __call__(self, llm: BaseLLM, **kwargs: Any) -> BaseCombineDocumentsChain: + def __call__( + self, llm: BaseLanguageModel, **kwargs: Any + ) -> BaseCombineDocumentsChain: """Callable to load the combine documents chain.""" def _load_map_rerank_chain( - llm: BaseLLM, + llm: BaseLanguageModel, prompt: BasePromptTemplate = map_rerank_prompt.PROMPT, verbose: bool = False, document_variable_name: str = "context", @@ -44,7 +46,7 @@ def _load_map_rerank_chain( def _load_stuff_chain( - llm: BaseLLM, + llm: BaseLanguageModel, prompt: BasePromptTemplate = stuff_prompt.PROMPT, document_prompt: BasePromptTemplate = stuff_prompt.EXAMPLE_PROMPT, document_variable_name: str = "summaries", @@ -62,15 +64,15 @@ def _load_stuff_chain( def _load_map_reduce_chain( - llm: BaseLLM, + llm: BaseLanguageModel, question_prompt: BasePromptTemplate = map_reduce_prompt.QUESTION_PROMPT, combine_prompt: BasePromptTemplate = map_reduce_prompt.COMBINE_PROMPT, document_prompt: BasePromptTemplate = map_reduce_prompt.EXAMPLE_PROMPT, combine_document_variable_name: str = "summaries", map_reduce_document_variable_name: str = "context", collapse_prompt: Optional[BasePromptTemplate] = None, - reduce_llm: Optional[BaseLLM] = None, - collapse_llm: Optional[BaseLLM] = None, + reduce_llm: Optional[BaseLanguageModel] = None, + collapse_llm: Optional[BaseLanguageModel] = None, verbose: Optional[bool] = None, **kwargs: Any, ) -> MapReduceDocumentsChain: @@ -112,13 +114,13 @@ def _load_map_reduce_chain( def _load_refine_chain( - llm: BaseLLM, + llm: BaseLanguageModel, question_prompt: BasePromptTemplate = refine_prompts.DEFAULT_TEXT_QA_PROMPT, refine_prompt: BasePromptTemplate = refine_prompts.DEFAULT_REFINE_PROMPT, document_prompt: BasePromptTemplate = refine_prompts.EXAMPLE_PROMPT, document_variable_name: str = "context_str", initial_response_name: str = "existing_answer", - refine_llm: Optional[BaseLLM] = None, + refine_llm: Optional[BaseLanguageModel] = None, verbose: Optional[bool] = None, **kwargs: Any, ) -> RefineDocumentsChain: @@ -137,7 +139,7 @@ def _load_refine_chain( def load_qa_with_sources_chain( - llm: BaseLLM, + llm: BaseLanguageModel, chain_type: str = "stuff", verbose: Optional[bool] = None, **kwargs: Any, diff --git a/langchain/chains/sql_database/base.py b/langchain/chains/sql_database/base.py index 37a3bd17..6f959014 100644 --- a/langchain/chains/sql_database/base.py +++ b/langchain/chains/sql_database/base.py @@ -8,8 +8,8 @@ from pydantic import BaseModel, Extra, Field from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.chains.sql_database.prompt import DECIDER_PROMPT, PROMPT -from langchain.llms.base import BaseLLM from langchain.prompts.base import BasePromptTemplate +from langchain.schema import BaseLanguageModel from langchain.sql_database import SQLDatabase @@ -24,7 +24,7 @@ class SQLDatabaseChain(Chain, BaseModel): db_chain = SQLDatabaseChain(llm=OpenAI(), database=db) """ - llm: BaseLLM + llm: BaseLanguageModel """LLM wrapper to use.""" database: SQLDatabase = Field(exclude=True) """SQL Database to connect to.""" @@ -122,7 +122,7 @@ class SQLDatabaseSequentialChain(Chain, BaseModel): @classmethod def from_llm( cls, - llm: BaseLLM, + llm: BaseLanguageModel, database: SQLDatabase, query_prompt: BasePromptTemplate = PROMPT, decider_prompt: BasePromptTemplate = DECIDER_PROMPT, diff --git a/langchain/chains/summarize/__init__.py b/langchain/chains/summarize/__init__.py index c6446f68..c31fda47 100644 --- a/langchain/chains/summarize/__init__.py +++ b/langchain/chains/summarize/__init__.py @@ -7,19 +7,21 @@ from langchain.chains.combine_documents.refine import RefineDocumentsChain from langchain.chains.combine_documents.stuff import StuffDocumentsChain from langchain.chains.llm import LLMChain from langchain.chains.summarize import map_reduce_prompt, refine_prompts, stuff_prompt -from langchain.llms.base import BaseLLM from langchain.prompts.base import BasePromptTemplate +from langchain.schema import BaseLanguageModel class LoadingCallable(Protocol): """Interface for loading the combine documents chain.""" - def __call__(self, llm: BaseLLM, **kwargs: Any) -> BaseCombineDocumentsChain: + def __call__( + self, llm: BaseLanguageModel, **kwargs: Any + ) -> BaseCombineDocumentsChain: """Callable to load the combine documents chain.""" def _load_stuff_chain( - llm: BaseLLM, + llm: BaseLanguageModel, prompt: BasePromptTemplate = stuff_prompt.PROMPT, document_variable_name: str = "text", verbose: Optional[bool] = None, @@ -36,14 +38,14 @@ def _load_stuff_chain( def _load_map_reduce_chain( - llm: BaseLLM, + llm: BaseLanguageModel, map_prompt: BasePromptTemplate = map_reduce_prompt.PROMPT, combine_prompt: BasePromptTemplate = map_reduce_prompt.PROMPT, combine_document_variable_name: str = "text", map_reduce_document_variable_name: str = "text", collapse_prompt: Optional[BasePromptTemplate] = None, - reduce_llm: Optional[BaseLLM] = None, - collapse_llm: Optional[BaseLLM] = None, + reduce_llm: Optional[BaseLanguageModel] = None, + collapse_llm: Optional[BaseLanguageModel] = None, verbose: Optional[bool] = None, **kwargs: Any, ) -> MapReduceDocumentsChain: @@ -84,12 +86,12 @@ def _load_map_reduce_chain( def _load_refine_chain( - llm: BaseLLM, + llm: BaseLanguageModel, question_prompt: BasePromptTemplate = refine_prompts.PROMPT, refine_prompt: BasePromptTemplate = refine_prompts.REFINE_PROMPT, document_variable_name: str = "text", initial_response_name: str = "existing_answer", - refine_llm: Optional[BaseLLM] = None, + refine_llm: Optional[BaseLanguageModel] = None, verbose: Optional[bool] = None, **kwargs: Any, ) -> RefineDocumentsChain: @@ -107,7 +109,7 @@ def _load_refine_chain( def load_summarize_chain( - llm: BaseLLM, + llm: BaseLanguageModel, chain_type: str = "stuff", verbose: Optional[bool] = None, **kwargs: Any, diff --git a/langchain/chains/vector_db_qa/base.py b/langchain/chains/vector_db_qa/base.py index 3da04659..16182b78 100644 --- a/langchain/chains/vector_db_qa/base.py +++ b/langchain/chains/vector_db_qa/base.py @@ -11,8 +11,8 @@ from langchain.chains.combine_documents.stuff import StuffDocumentsChain from langchain.chains.llm import LLMChain from langchain.chains.question_answering import load_qa_chain from langchain.chains.question_answering.stuff_prompt import PROMPT_SELECTOR -from langchain.llms.base import BaseLLM from langchain.prompts import PromptTemplate +from langchain.schema import BaseLanguageModel from langchain.vectorstores.base import VectorStore @@ -103,7 +103,10 @@ class VectorDBQA(Chain, BaseModel): @classmethod def from_llm( - cls, llm: BaseLLM, prompt: Optional[PromptTemplate] = None, **kwargs: Any + cls, + llm: BaseLanguageModel, + prompt: Optional[PromptTemplate] = None, + **kwargs: Any, ) -> VectorDBQA: """Initialize from LLM.""" _prompt = prompt or PROMPT_SELECTOR.get_prompt(llm) @@ -122,7 +125,7 @@ class VectorDBQA(Chain, BaseModel): @classmethod def from_chain_type( cls, - llm: BaseLLM, + llm: BaseLanguageModel, chain_type: str = "stuff", chain_type_kwargs: Optional[dict] = None, **kwargs: Any,