fix typing (#1807)

This commit is contained in:
Harrison Chase 2023-03-20 07:50:49 -07:00 committed by GitHub
parent b6ba989f2f
commit b1c4480d7c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 47 additions and 38 deletions

View File

@ -8,9 +8,9 @@ from pydantic import BaseModel, Field, root_validator
from langchain.chains.api.prompt import API_RESPONSE_PROMPT, API_URL_PROMPT from langchain.chains.api.prompt import API_RESPONSE_PROMPT, API_URL_PROMPT
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.llms.base import BaseLLM
from langchain.prompts import BasePromptTemplate from langchain.prompts import BasePromptTemplate
from langchain.requests import RequestsWrapper from langchain.requests import RequestsWrapper
from langchain.schema import BaseLanguageModel
class APIChain(Chain, BaseModel): class APIChain(Chain, BaseModel):
@ -84,7 +84,7 @@ class APIChain(Chain, BaseModel):
@classmethod @classmethod
def from_llm_and_api_docs( def from_llm_and_api_docs(
cls, cls,
llm: BaseLLM, llm: BaseLanguageModel,
api_docs: str, api_docs: str,
headers: Optional[dict] = None, headers: Optional[dict] = None,
api_url_prompt: BasePromptTemplate = API_URL_PROMPT, api_url_prompt: BasePromptTemplate = API_URL_PROMPT,

View File

@ -5,8 +5,8 @@ from langchain.chains.base import Chain
from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple
from langchain.chains.constitutional_ai.prompts import CRITIQUE_PROMPT, REVISION_PROMPT from langchain.chains.constitutional_ai.prompts import CRITIQUE_PROMPT, REVISION_PROMPT
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.llms.base import BaseLLM
from langchain.prompts.base import BasePromptTemplate from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel
class ConstitutionalChain(Chain): class ConstitutionalChain(Chain):
@ -45,7 +45,7 @@ class ConstitutionalChain(Chain):
@classmethod @classmethod
def from_llm( def from_llm(
cls, cls,
llm: BaseLLM, llm: BaseLanguageModel,
chain: LLMChain, chain: LLMChain,
critique_prompt: BasePromptTemplate = CRITIQUE_PROMPT, critique_prompt: BasePromptTemplate = CRITIQUE_PROMPT,
revision_prompt: BasePromptTemplate = REVISION_PROMPT, revision_prompt: BasePromptTemplate = REVISION_PROMPT,

View File

@ -6,8 +6,8 @@ from pydantic import BaseModel, Extra
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.chains.llm_bash.prompt import PROMPT from langchain.chains.llm_bash.prompt import PROMPT
from langchain.llms.base import BaseLLM
from langchain.prompts.base import BasePromptTemplate from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel
from langchain.utilities.bash import BashProcess from langchain.utilities.bash import BashProcess
@ -21,7 +21,7 @@ class LLMBashChain(Chain, BaseModel):
llm_bash = LLMBashChain(llm=OpenAI()) llm_bash = LLMBashChain(llm=OpenAI())
""" """
llm: BaseLLM llm: BaseLanguageModel
"""LLM wrapper to use.""" """LLM wrapper to use."""
input_key: str = "question" #: :meta private: input_key: str = "question" #: :meta private:
output_key: str = "answer" #: :meta private: output_key: str = "answer" #: :meta private:

View File

@ -12,15 +12,15 @@ from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.chains.pal.colored_object_prompt import COLORED_OBJECT_PROMPT from langchain.chains.pal.colored_object_prompt import COLORED_OBJECT_PROMPT
from langchain.chains.pal.math_prompt import MATH_PROMPT from langchain.chains.pal.math_prompt import MATH_PROMPT
from langchain.llms.base import BaseLLM
from langchain.prompts.base import BasePromptTemplate from langchain.prompts.base import BasePromptTemplate
from langchain.python import PythonREPL from langchain.python import PythonREPL
from langchain.schema import BaseLanguageModel
class PALChain(Chain, BaseModel): class PALChain(Chain, BaseModel):
"""Implements Program-Aided Language Models.""" """Implements Program-Aided Language Models."""
llm: BaseLLM llm: BaseLanguageModel
prompt: BasePromptTemplate prompt: BasePromptTemplate
stop: str = "\n\n" stop: str = "\n\n"
get_answer_expr: str = "print(solution())" get_answer_expr: str = "print(solution())"
@ -68,7 +68,7 @@ class PALChain(Chain, BaseModel):
return output return output
@classmethod @classmethod
def from_math_prompt(cls, llm: BaseLLM, **kwargs: Any) -> PALChain: def from_math_prompt(cls, llm: BaseLanguageModel, **kwargs: Any) -> PALChain:
"""Load PAL from math prompt.""" """Load PAL from math prompt."""
return cls( return cls(
llm=llm, llm=llm,
@ -79,7 +79,9 @@ class PALChain(Chain, BaseModel):
) )
@classmethod @classmethod
def from_colored_object_prompt(cls, llm: BaseLLM, **kwargs: Any) -> PALChain: def from_colored_object_prompt(
cls, llm: BaseLanguageModel, **kwargs: Any
) -> PALChain:
"""Load PAL from colored object prompt.""" """Load PAL from colored object prompt."""
return cls( return cls(
llm=llm, llm=llm,

View File

@ -19,8 +19,8 @@ from langchain.chains.qa_with_sources.map_reduce_prompt import (
QUESTION_PROMPT, QUESTION_PROMPT,
) )
from langchain.docstore.document import Document from langchain.docstore.document import Document
from langchain.llms.base import BaseLLM
from langchain.prompts.base import BasePromptTemplate from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel
class BaseQAWithSourcesChain(Chain, BaseModel, ABC): class BaseQAWithSourcesChain(Chain, BaseModel, ABC):
@ -38,7 +38,7 @@ class BaseQAWithSourcesChain(Chain, BaseModel, ABC):
@classmethod @classmethod
def from_llm( def from_llm(
cls, cls,
llm: BaseLLM, llm: BaseLanguageModel,
document_prompt: BasePromptTemplate = EXAMPLE_PROMPT, document_prompt: BasePromptTemplate = EXAMPLE_PROMPT,
question_prompt: BasePromptTemplate = QUESTION_PROMPT, question_prompt: BasePromptTemplate = QUESTION_PROMPT,
combine_prompt: BasePromptTemplate = COMBINE_PROMPT, combine_prompt: BasePromptTemplate = COMBINE_PROMPT,
@ -65,7 +65,7 @@ class BaseQAWithSourcesChain(Chain, BaseModel, ABC):
@classmethod @classmethod
def from_chain_type( def from_chain_type(
cls, cls,
llm: BaseLLM, llm: BaseLanguageModel,
chain_type: str = "stuff", chain_type: str = "stuff",
chain_type_kwargs: Optional[dict] = None, chain_type_kwargs: Optional[dict] = None,
**kwargs: Any, **kwargs: Any,

View File

@ -13,19 +13,21 @@ from langchain.chains.qa_with_sources import (
stuff_prompt, stuff_prompt,
) )
from langchain.chains.question_answering import map_rerank_prompt from langchain.chains.question_answering import map_rerank_prompt
from langchain.llms.base import BaseLLM
from langchain.prompts.base import BasePromptTemplate from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel
class LoadingCallable(Protocol): class LoadingCallable(Protocol):
"""Interface for loading the combine documents chain.""" """Interface for loading the combine documents chain."""
def __call__(self, llm: BaseLLM, **kwargs: Any) -> BaseCombineDocumentsChain: def __call__(
self, llm: BaseLanguageModel, **kwargs: Any
) -> BaseCombineDocumentsChain:
"""Callable to load the combine documents chain.""" """Callable to load the combine documents chain."""
def _load_map_rerank_chain( def _load_map_rerank_chain(
llm: BaseLLM, llm: BaseLanguageModel,
prompt: BasePromptTemplate = map_rerank_prompt.PROMPT, prompt: BasePromptTemplate = map_rerank_prompt.PROMPT,
verbose: bool = False, verbose: bool = False,
document_variable_name: str = "context", document_variable_name: str = "context",
@ -44,7 +46,7 @@ def _load_map_rerank_chain(
def _load_stuff_chain( def _load_stuff_chain(
llm: BaseLLM, llm: BaseLanguageModel,
prompt: BasePromptTemplate = stuff_prompt.PROMPT, prompt: BasePromptTemplate = stuff_prompt.PROMPT,
document_prompt: BasePromptTemplate = stuff_prompt.EXAMPLE_PROMPT, document_prompt: BasePromptTemplate = stuff_prompt.EXAMPLE_PROMPT,
document_variable_name: str = "summaries", document_variable_name: str = "summaries",
@ -62,15 +64,15 @@ def _load_stuff_chain(
def _load_map_reduce_chain( def _load_map_reduce_chain(
llm: BaseLLM, llm: BaseLanguageModel,
question_prompt: BasePromptTemplate = map_reduce_prompt.QUESTION_PROMPT, question_prompt: BasePromptTemplate = map_reduce_prompt.QUESTION_PROMPT,
combine_prompt: BasePromptTemplate = map_reduce_prompt.COMBINE_PROMPT, combine_prompt: BasePromptTemplate = map_reduce_prompt.COMBINE_PROMPT,
document_prompt: BasePromptTemplate = map_reduce_prompt.EXAMPLE_PROMPT, document_prompt: BasePromptTemplate = map_reduce_prompt.EXAMPLE_PROMPT,
combine_document_variable_name: str = "summaries", combine_document_variable_name: str = "summaries",
map_reduce_document_variable_name: str = "context", map_reduce_document_variable_name: str = "context",
collapse_prompt: Optional[BasePromptTemplate] = None, collapse_prompt: Optional[BasePromptTemplate] = None,
reduce_llm: Optional[BaseLLM] = None, reduce_llm: Optional[BaseLanguageModel] = None,
collapse_llm: Optional[BaseLLM] = None, collapse_llm: Optional[BaseLanguageModel] = None,
verbose: Optional[bool] = None, verbose: Optional[bool] = None,
**kwargs: Any, **kwargs: Any,
) -> MapReduceDocumentsChain: ) -> MapReduceDocumentsChain:
@ -112,13 +114,13 @@ def _load_map_reduce_chain(
def _load_refine_chain( def _load_refine_chain(
llm: BaseLLM, llm: BaseLanguageModel,
question_prompt: BasePromptTemplate = refine_prompts.DEFAULT_TEXT_QA_PROMPT, question_prompt: BasePromptTemplate = refine_prompts.DEFAULT_TEXT_QA_PROMPT,
refine_prompt: BasePromptTemplate = refine_prompts.DEFAULT_REFINE_PROMPT, refine_prompt: BasePromptTemplate = refine_prompts.DEFAULT_REFINE_PROMPT,
document_prompt: BasePromptTemplate = refine_prompts.EXAMPLE_PROMPT, document_prompt: BasePromptTemplate = refine_prompts.EXAMPLE_PROMPT,
document_variable_name: str = "context_str", document_variable_name: str = "context_str",
initial_response_name: str = "existing_answer", initial_response_name: str = "existing_answer",
refine_llm: Optional[BaseLLM] = None, refine_llm: Optional[BaseLanguageModel] = None,
verbose: Optional[bool] = None, verbose: Optional[bool] = None,
**kwargs: Any, **kwargs: Any,
) -> RefineDocumentsChain: ) -> RefineDocumentsChain:
@ -137,7 +139,7 @@ def _load_refine_chain(
def load_qa_with_sources_chain( def load_qa_with_sources_chain(
llm: BaseLLM, llm: BaseLanguageModel,
chain_type: str = "stuff", chain_type: str = "stuff",
verbose: Optional[bool] = None, verbose: Optional[bool] = None,
**kwargs: Any, **kwargs: Any,

View File

@ -8,8 +8,8 @@ from pydantic import BaseModel, Extra, Field
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.chains.sql_database.prompt import DECIDER_PROMPT, PROMPT from langchain.chains.sql_database.prompt import DECIDER_PROMPT, PROMPT
from langchain.llms.base import BaseLLM
from langchain.prompts.base import BasePromptTemplate from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel
from langchain.sql_database import SQLDatabase from langchain.sql_database import SQLDatabase
@ -24,7 +24,7 @@ class SQLDatabaseChain(Chain, BaseModel):
db_chain = SQLDatabaseChain(llm=OpenAI(), database=db) db_chain = SQLDatabaseChain(llm=OpenAI(), database=db)
""" """
llm: BaseLLM llm: BaseLanguageModel
"""LLM wrapper to use.""" """LLM wrapper to use."""
database: SQLDatabase = Field(exclude=True) database: SQLDatabase = Field(exclude=True)
"""SQL Database to connect to.""" """SQL Database to connect to."""
@ -122,7 +122,7 @@ class SQLDatabaseSequentialChain(Chain, BaseModel):
@classmethod @classmethod
def from_llm( def from_llm(
cls, cls,
llm: BaseLLM, llm: BaseLanguageModel,
database: SQLDatabase, database: SQLDatabase,
query_prompt: BasePromptTemplate = PROMPT, query_prompt: BasePromptTemplate = PROMPT,
decider_prompt: BasePromptTemplate = DECIDER_PROMPT, decider_prompt: BasePromptTemplate = DECIDER_PROMPT,

View File

@ -7,19 +7,21 @@ from langchain.chains.combine_documents.refine import RefineDocumentsChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.chains.summarize import map_reduce_prompt, refine_prompts, stuff_prompt from langchain.chains.summarize import map_reduce_prompt, refine_prompts, stuff_prompt
from langchain.llms.base import BaseLLM
from langchain.prompts.base import BasePromptTemplate from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel
class LoadingCallable(Protocol): class LoadingCallable(Protocol):
"""Interface for loading the combine documents chain.""" """Interface for loading the combine documents chain."""
def __call__(self, llm: BaseLLM, **kwargs: Any) -> BaseCombineDocumentsChain: def __call__(
self, llm: BaseLanguageModel, **kwargs: Any
) -> BaseCombineDocumentsChain:
"""Callable to load the combine documents chain.""" """Callable to load the combine documents chain."""
def _load_stuff_chain( def _load_stuff_chain(
llm: BaseLLM, llm: BaseLanguageModel,
prompt: BasePromptTemplate = stuff_prompt.PROMPT, prompt: BasePromptTemplate = stuff_prompt.PROMPT,
document_variable_name: str = "text", document_variable_name: str = "text",
verbose: Optional[bool] = None, verbose: Optional[bool] = None,
@ -36,14 +38,14 @@ def _load_stuff_chain(
def _load_map_reduce_chain( def _load_map_reduce_chain(
llm: BaseLLM, llm: BaseLanguageModel,
map_prompt: BasePromptTemplate = map_reduce_prompt.PROMPT, map_prompt: BasePromptTemplate = map_reduce_prompt.PROMPT,
combine_prompt: BasePromptTemplate = map_reduce_prompt.PROMPT, combine_prompt: BasePromptTemplate = map_reduce_prompt.PROMPT,
combine_document_variable_name: str = "text", combine_document_variable_name: str = "text",
map_reduce_document_variable_name: str = "text", map_reduce_document_variable_name: str = "text",
collapse_prompt: Optional[BasePromptTemplate] = None, collapse_prompt: Optional[BasePromptTemplate] = None,
reduce_llm: Optional[BaseLLM] = None, reduce_llm: Optional[BaseLanguageModel] = None,
collapse_llm: Optional[BaseLLM] = None, collapse_llm: Optional[BaseLanguageModel] = None,
verbose: Optional[bool] = None, verbose: Optional[bool] = None,
**kwargs: Any, **kwargs: Any,
) -> MapReduceDocumentsChain: ) -> MapReduceDocumentsChain:
@ -84,12 +86,12 @@ def _load_map_reduce_chain(
def _load_refine_chain( def _load_refine_chain(
llm: BaseLLM, llm: BaseLanguageModel,
question_prompt: BasePromptTemplate = refine_prompts.PROMPT, question_prompt: BasePromptTemplate = refine_prompts.PROMPT,
refine_prompt: BasePromptTemplate = refine_prompts.REFINE_PROMPT, refine_prompt: BasePromptTemplate = refine_prompts.REFINE_PROMPT,
document_variable_name: str = "text", document_variable_name: str = "text",
initial_response_name: str = "existing_answer", initial_response_name: str = "existing_answer",
refine_llm: Optional[BaseLLM] = None, refine_llm: Optional[BaseLanguageModel] = None,
verbose: Optional[bool] = None, verbose: Optional[bool] = None,
**kwargs: Any, **kwargs: Any,
) -> RefineDocumentsChain: ) -> RefineDocumentsChain:
@ -107,7 +109,7 @@ def _load_refine_chain(
def load_summarize_chain( def load_summarize_chain(
llm: BaseLLM, llm: BaseLanguageModel,
chain_type: str = "stuff", chain_type: str = "stuff",
verbose: Optional[bool] = None, verbose: Optional[bool] = None,
**kwargs: Any, **kwargs: Any,

View File

@ -11,8 +11,8 @@ from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.chains.question_answering import load_qa_chain from langchain.chains.question_answering import load_qa_chain
from langchain.chains.question_answering.stuff_prompt import PROMPT_SELECTOR from langchain.chains.question_answering.stuff_prompt import PROMPT_SELECTOR
from langchain.llms.base import BaseLLM
from langchain.prompts import PromptTemplate from langchain.prompts import PromptTemplate
from langchain.schema import BaseLanguageModel
from langchain.vectorstores.base import VectorStore from langchain.vectorstores.base import VectorStore
@ -103,7 +103,10 @@ class VectorDBQA(Chain, BaseModel):
@classmethod @classmethod
def from_llm( def from_llm(
cls, llm: BaseLLM, prompt: Optional[PromptTemplate] = None, **kwargs: Any cls,
llm: BaseLanguageModel,
prompt: Optional[PromptTemplate] = None,
**kwargs: Any,
) -> VectorDBQA: ) -> VectorDBQA:
"""Initialize from LLM.""" """Initialize from LLM."""
_prompt = prompt or PROMPT_SELECTOR.get_prompt(llm) _prompt = prompt or PROMPT_SELECTOR.get_prompt(llm)
@ -122,7 +125,7 @@ class VectorDBQA(Chain, BaseModel):
@classmethod @classmethod
def from_chain_type( def from_chain_type(
cls, cls,
llm: BaseLLM, llm: BaseLanguageModel,
chain_type: str = "stuff", chain_type: str = "stuff",
chain_type_kwargs: Optional[dict] = None, chain_type_kwargs: Optional[dict] = None,
**kwargs: Any, **kwargs: Any,