fix typing (#1807)

This commit is contained in:
Harrison Chase 2023-03-20 07:50:49 -07:00 committed by GitHub
parent b6ba989f2f
commit b1c4480d7c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 47 additions and 38 deletions

View File

@ -8,9 +8,9 @@ from pydantic import BaseModel, Field, root_validator
from langchain.chains.api.prompt import API_RESPONSE_PROMPT, API_URL_PROMPT
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.llms.base import BaseLLM
from langchain.prompts import BasePromptTemplate
from langchain.requests import RequestsWrapper
from langchain.schema import BaseLanguageModel
class APIChain(Chain, BaseModel):
@ -84,7 +84,7 @@ class APIChain(Chain, BaseModel):
@classmethod
def from_llm_and_api_docs(
cls,
llm: BaseLLM,
llm: BaseLanguageModel,
api_docs: str,
headers: Optional[dict] = None,
api_url_prompt: BasePromptTemplate = API_URL_PROMPT,

View File

@ -5,8 +5,8 @@ from langchain.chains.base import Chain
from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple
from langchain.chains.constitutional_ai.prompts import CRITIQUE_PROMPT, REVISION_PROMPT
from langchain.chains.llm import LLMChain
from langchain.llms.base import BaseLLM
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel
class ConstitutionalChain(Chain):
@ -45,7 +45,7 @@ class ConstitutionalChain(Chain):
@classmethod
def from_llm(
cls,
llm: BaseLLM,
llm: BaseLanguageModel,
chain: LLMChain,
critique_prompt: BasePromptTemplate = CRITIQUE_PROMPT,
revision_prompt: BasePromptTemplate = REVISION_PROMPT,

View File

@ -6,8 +6,8 @@ from pydantic import BaseModel, Extra
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.llm_bash.prompt import PROMPT
from langchain.llms.base import BaseLLM
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel
from langchain.utilities.bash import BashProcess
@ -21,7 +21,7 @@ class LLMBashChain(Chain, BaseModel):
llm_bash = LLMBashChain(llm=OpenAI())
"""
llm: BaseLLM
llm: BaseLanguageModel
"""LLM wrapper to use."""
input_key: str = "question" #: :meta private:
output_key: str = "answer" #: :meta private:

View File

@ -12,15 +12,15 @@ from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.pal.colored_object_prompt import COLORED_OBJECT_PROMPT
from langchain.chains.pal.math_prompt import MATH_PROMPT
from langchain.llms.base import BaseLLM
from langchain.prompts.base import BasePromptTemplate
from langchain.python import PythonREPL
from langchain.schema import BaseLanguageModel
class PALChain(Chain, BaseModel):
"""Implements Program-Aided Language Models."""
llm: BaseLLM
llm: BaseLanguageModel
prompt: BasePromptTemplate
stop: str = "\n\n"
get_answer_expr: str = "print(solution())"
@ -68,7 +68,7 @@ class PALChain(Chain, BaseModel):
return output
@classmethod
def from_math_prompt(cls, llm: BaseLLM, **kwargs: Any) -> PALChain:
def from_math_prompt(cls, llm: BaseLanguageModel, **kwargs: Any) -> PALChain:
"""Load PAL from math prompt."""
return cls(
llm=llm,
@ -79,7 +79,9 @@ class PALChain(Chain, BaseModel):
)
@classmethod
def from_colored_object_prompt(cls, llm: BaseLLM, **kwargs: Any) -> PALChain:
def from_colored_object_prompt(
cls, llm: BaseLanguageModel, **kwargs: Any
) -> PALChain:
"""Load PAL from colored object prompt."""
return cls(
llm=llm,

View File

@ -19,8 +19,8 @@ from langchain.chains.qa_with_sources.map_reduce_prompt import (
QUESTION_PROMPT,
)
from langchain.docstore.document import Document
from langchain.llms.base import BaseLLM
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel
class BaseQAWithSourcesChain(Chain, BaseModel, ABC):
@ -38,7 +38,7 @@ class BaseQAWithSourcesChain(Chain, BaseModel, ABC):
@classmethod
def from_llm(
cls,
llm: BaseLLM,
llm: BaseLanguageModel,
document_prompt: BasePromptTemplate = EXAMPLE_PROMPT,
question_prompt: BasePromptTemplate = QUESTION_PROMPT,
combine_prompt: BasePromptTemplate = COMBINE_PROMPT,
@ -65,7 +65,7 @@ class BaseQAWithSourcesChain(Chain, BaseModel, ABC):
@classmethod
def from_chain_type(
cls,
llm: BaseLLM,
llm: BaseLanguageModel,
chain_type: str = "stuff",
chain_type_kwargs: Optional[dict] = None,
**kwargs: Any,

View File

@ -13,19 +13,21 @@ from langchain.chains.qa_with_sources import (
stuff_prompt,
)
from langchain.chains.question_answering import map_rerank_prompt
from langchain.llms.base import BaseLLM
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel
class LoadingCallable(Protocol):
"""Interface for loading the combine documents chain."""
def __call__(self, llm: BaseLLM, **kwargs: Any) -> BaseCombineDocumentsChain:
def __call__(
self, llm: BaseLanguageModel, **kwargs: Any
) -> BaseCombineDocumentsChain:
"""Callable to load the combine documents chain."""
def _load_map_rerank_chain(
llm: BaseLLM,
llm: BaseLanguageModel,
prompt: BasePromptTemplate = map_rerank_prompt.PROMPT,
verbose: bool = False,
document_variable_name: str = "context",
@ -44,7 +46,7 @@ def _load_map_rerank_chain(
def _load_stuff_chain(
llm: BaseLLM,
llm: BaseLanguageModel,
prompt: BasePromptTemplate = stuff_prompt.PROMPT,
document_prompt: BasePromptTemplate = stuff_prompt.EXAMPLE_PROMPT,
document_variable_name: str = "summaries",
@ -62,15 +64,15 @@ def _load_stuff_chain(
def _load_map_reduce_chain(
llm: BaseLLM,
llm: BaseLanguageModel,
question_prompt: BasePromptTemplate = map_reduce_prompt.QUESTION_PROMPT,
combine_prompt: BasePromptTemplate = map_reduce_prompt.COMBINE_PROMPT,
document_prompt: BasePromptTemplate = map_reduce_prompt.EXAMPLE_PROMPT,
combine_document_variable_name: str = "summaries",
map_reduce_document_variable_name: str = "context",
collapse_prompt: Optional[BasePromptTemplate] = None,
reduce_llm: Optional[BaseLLM] = None,
collapse_llm: Optional[BaseLLM] = None,
reduce_llm: Optional[BaseLanguageModel] = None,
collapse_llm: Optional[BaseLanguageModel] = None,
verbose: Optional[bool] = None,
**kwargs: Any,
) -> MapReduceDocumentsChain:
@ -112,13 +114,13 @@ def _load_map_reduce_chain(
def _load_refine_chain(
llm: BaseLLM,
llm: BaseLanguageModel,
question_prompt: BasePromptTemplate = refine_prompts.DEFAULT_TEXT_QA_PROMPT,
refine_prompt: BasePromptTemplate = refine_prompts.DEFAULT_REFINE_PROMPT,
document_prompt: BasePromptTemplate = refine_prompts.EXAMPLE_PROMPT,
document_variable_name: str = "context_str",
initial_response_name: str = "existing_answer",
refine_llm: Optional[BaseLLM] = None,
refine_llm: Optional[BaseLanguageModel] = None,
verbose: Optional[bool] = None,
**kwargs: Any,
) -> RefineDocumentsChain:
@ -137,7 +139,7 @@ def _load_refine_chain(
def load_qa_with_sources_chain(
llm: BaseLLM,
llm: BaseLanguageModel,
chain_type: str = "stuff",
verbose: Optional[bool] = None,
**kwargs: Any,

View File

@ -8,8 +8,8 @@ from pydantic import BaseModel, Extra, Field
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.sql_database.prompt import DECIDER_PROMPT, PROMPT
from langchain.llms.base import BaseLLM
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel
from langchain.sql_database import SQLDatabase
@ -24,7 +24,7 @@ class SQLDatabaseChain(Chain, BaseModel):
db_chain = SQLDatabaseChain(llm=OpenAI(), database=db)
"""
llm: BaseLLM
llm: BaseLanguageModel
"""LLM wrapper to use."""
database: SQLDatabase = Field(exclude=True)
"""SQL Database to connect to."""
@ -122,7 +122,7 @@ class SQLDatabaseSequentialChain(Chain, BaseModel):
@classmethod
def from_llm(
cls,
llm: BaseLLM,
llm: BaseLanguageModel,
database: SQLDatabase,
query_prompt: BasePromptTemplate = PROMPT,
decider_prompt: BasePromptTemplate = DECIDER_PROMPT,

View File

@ -7,19 +7,21 @@ from langchain.chains.combine_documents.refine import RefineDocumentsChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.chains.summarize import map_reduce_prompt, refine_prompts, stuff_prompt
from langchain.llms.base import BaseLLM
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel
class LoadingCallable(Protocol):
"""Interface for loading the combine documents chain."""
def __call__(self, llm: BaseLLM, **kwargs: Any) -> BaseCombineDocumentsChain:
def __call__(
self, llm: BaseLanguageModel, **kwargs: Any
) -> BaseCombineDocumentsChain:
"""Callable to load the combine documents chain."""
def _load_stuff_chain(
llm: BaseLLM,
llm: BaseLanguageModel,
prompt: BasePromptTemplate = stuff_prompt.PROMPT,
document_variable_name: str = "text",
verbose: Optional[bool] = None,
@ -36,14 +38,14 @@ def _load_stuff_chain(
def _load_map_reduce_chain(
llm: BaseLLM,
llm: BaseLanguageModel,
map_prompt: BasePromptTemplate = map_reduce_prompt.PROMPT,
combine_prompt: BasePromptTemplate = map_reduce_prompt.PROMPT,
combine_document_variable_name: str = "text",
map_reduce_document_variable_name: str = "text",
collapse_prompt: Optional[BasePromptTemplate] = None,
reduce_llm: Optional[BaseLLM] = None,
collapse_llm: Optional[BaseLLM] = None,
reduce_llm: Optional[BaseLanguageModel] = None,
collapse_llm: Optional[BaseLanguageModel] = None,
verbose: Optional[bool] = None,
**kwargs: Any,
) -> MapReduceDocumentsChain:
@ -84,12 +86,12 @@ def _load_map_reduce_chain(
def _load_refine_chain(
llm: BaseLLM,
llm: BaseLanguageModel,
question_prompt: BasePromptTemplate = refine_prompts.PROMPT,
refine_prompt: BasePromptTemplate = refine_prompts.REFINE_PROMPT,
document_variable_name: str = "text",
initial_response_name: str = "existing_answer",
refine_llm: Optional[BaseLLM] = None,
refine_llm: Optional[BaseLanguageModel] = None,
verbose: Optional[bool] = None,
**kwargs: Any,
) -> RefineDocumentsChain:
@ -107,7 +109,7 @@ def _load_refine_chain(
def load_summarize_chain(
llm: BaseLLM,
llm: BaseLanguageModel,
chain_type: str = "stuff",
verbose: Optional[bool] = None,
**kwargs: Any,

View File

@ -11,8 +11,8 @@ from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.chains.question_answering import load_qa_chain
from langchain.chains.question_answering.stuff_prompt import PROMPT_SELECTOR
from langchain.llms.base import BaseLLM
from langchain.prompts import PromptTemplate
from langchain.schema import BaseLanguageModel
from langchain.vectorstores.base import VectorStore
@ -103,7 +103,10 @@ class VectorDBQA(Chain, BaseModel):
@classmethod
def from_llm(
cls, llm: BaseLLM, prompt: Optional[PromptTemplate] = None, **kwargs: Any
cls,
llm: BaseLanguageModel,
prompt: Optional[PromptTemplate] = None,
**kwargs: Any,
) -> VectorDBQA:
"""Initialize from LLM."""
_prompt = prompt or PROMPT_SELECTOR.get_prompt(llm)
@ -122,7 +125,7 @@ class VectorDBQA(Chain, BaseModel):
@classmethod
def from_chain_type(
cls,
llm: BaseLLM,
llm: BaseLanguageModel,
chain_type: str = "stuff",
chain_type_kwargs: Optional[dict] = None,
**kwargs: Any,