forked from Archives/langchain
fix typing (#1807)
This commit is contained in:
parent
b6ba989f2f
commit
b1c4480d7c
@ -8,9 +8,9 @@ from pydantic import BaseModel, Field, root_validator
|
||||
from langchain.chains.api.prompt import API_RESPONSE_PROMPT, API_URL_PROMPT
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.prompts import BasePromptTemplate
|
||||
from langchain.requests import RequestsWrapper
|
||||
from langchain.schema import BaseLanguageModel
|
||||
|
||||
|
||||
class APIChain(Chain, BaseModel):
|
||||
@ -84,7 +84,7 @@ class APIChain(Chain, BaseModel):
|
||||
@classmethod
|
||||
def from_llm_and_api_docs(
|
||||
cls,
|
||||
llm: BaseLLM,
|
||||
llm: BaseLanguageModel,
|
||||
api_docs: str,
|
||||
headers: Optional[dict] = None,
|
||||
api_url_prompt: BasePromptTemplate = API_URL_PROMPT,
|
||||
|
@ -5,8 +5,8 @@ from langchain.chains.base import Chain
|
||||
from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple
|
||||
from langchain.chains.constitutional_ai.prompts import CRITIQUE_PROMPT, REVISION_PROMPT
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.schema import BaseLanguageModel
|
||||
|
||||
|
||||
class ConstitutionalChain(Chain):
|
||||
@ -45,7 +45,7 @@ class ConstitutionalChain(Chain):
|
||||
@classmethod
|
||||
def from_llm(
|
||||
cls,
|
||||
llm: BaseLLM,
|
||||
llm: BaseLanguageModel,
|
||||
chain: LLMChain,
|
||||
critique_prompt: BasePromptTemplate = CRITIQUE_PROMPT,
|
||||
revision_prompt: BasePromptTemplate = REVISION_PROMPT,
|
||||
|
@ -6,8 +6,8 @@ from pydantic import BaseModel, Extra
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.llm_bash.prompt import PROMPT
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.schema import BaseLanguageModel
|
||||
from langchain.utilities.bash import BashProcess
|
||||
|
||||
|
||||
@ -21,7 +21,7 @@ class LLMBashChain(Chain, BaseModel):
|
||||
llm_bash = LLMBashChain(llm=OpenAI())
|
||||
"""
|
||||
|
||||
llm: BaseLLM
|
||||
llm: BaseLanguageModel
|
||||
"""LLM wrapper to use."""
|
||||
input_key: str = "question" #: :meta private:
|
||||
output_key: str = "answer" #: :meta private:
|
||||
|
@ -12,15 +12,15 @@ from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.pal.colored_object_prompt import COLORED_OBJECT_PROMPT
|
||||
from langchain.chains.pal.math_prompt import MATH_PROMPT
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.python import PythonREPL
|
||||
from langchain.schema import BaseLanguageModel
|
||||
|
||||
|
||||
class PALChain(Chain, BaseModel):
|
||||
"""Implements Program-Aided Language Models."""
|
||||
|
||||
llm: BaseLLM
|
||||
llm: BaseLanguageModel
|
||||
prompt: BasePromptTemplate
|
||||
stop: str = "\n\n"
|
||||
get_answer_expr: str = "print(solution())"
|
||||
@ -68,7 +68,7 @@ class PALChain(Chain, BaseModel):
|
||||
return output
|
||||
|
||||
@classmethod
|
||||
def from_math_prompt(cls, llm: BaseLLM, **kwargs: Any) -> PALChain:
|
||||
def from_math_prompt(cls, llm: BaseLanguageModel, **kwargs: Any) -> PALChain:
|
||||
"""Load PAL from math prompt."""
|
||||
return cls(
|
||||
llm=llm,
|
||||
@ -79,7 +79,9 @@ class PALChain(Chain, BaseModel):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_colored_object_prompt(cls, llm: BaseLLM, **kwargs: Any) -> PALChain:
|
||||
def from_colored_object_prompt(
|
||||
cls, llm: BaseLanguageModel, **kwargs: Any
|
||||
) -> PALChain:
|
||||
"""Load PAL from colored object prompt."""
|
||||
return cls(
|
||||
llm=llm,
|
||||
|
@ -19,8 +19,8 @@ from langchain.chains.qa_with_sources.map_reduce_prompt import (
|
||||
QUESTION_PROMPT,
|
||||
)
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.schema import BaseLanguageModel
|
||||
|
||||
|
||||
class BaseQAWithSourcesChain(Chain, BaseModel, ABC):
|
||||
@ -38,7 +38,7 @@ class BaseQAWithSourcesChain(Chain, BaseModel, ABC):
|
||||
@classmethod
|
||||
def from_llm(
|
||||
cls,
|
||||
llm: BaseLLM,
|
||||
llm: BaseLanguageModel,
|
||||
document_prompt: BasePromptTemplate = EXAMPLE_PROMPT,
|
||||
question_prompt: BasePromptTemplate = QUESTION_PROMPT,
|
||||
combine_prompt: BasePromptTemplate = COMBINE_PROMPT,
|
||||
@ -65,7 +65,7 @@ class BaseQAWithSourcesChain(Chain, BaseModel, ABC):
|
||||
@classmethod
|
||||
def from_chain_type(
|
||||
cls,
|
||||
llm: BaseLLM,
|
||||
llm: BaseLanguageModel,
|
||||
chain_type: str = "stuff",
|
||||
chain_type_kwargs: Optional[dict] = None,
|
||||
**kwargs: Any,
|
||||
|
@ -13,19 +13,21 @@ from langchain.chains.qa_with_sources import (
|
||||
stuff_prompt,
|
||||
)
|
||||
from langchain.chains.question_answering import map_rerank_prompt
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.schema import BaseLanguageModel
|
||||
|
||||
|
||||
class LoadingCallable(Protocol):
|
||||
"""Interface for loading the combine documents chain."""
|
||||
|
||||
def __call__(self, llm: BaseLLM, **kwargs: Any) -> BaseCombineDocumentsChain:
|
||||
def __call__(
|
||||
self, llm: BaseLanguageModel, **kwargs: Any
|
||||
) -> BaseCombineDocumentsChain:
|
||||
"""Callable to load the combine documents chain."""
|
||||
|
||||
|
||||
def _load_map_rerank_chain(
|
||||
llm: BaseLLM,
|
||||
llm: BaseLanguageModel,
|
||||
prompt: BasePromptTemplate = map_rerank_prompt.PROMPT,
|
||||
verbose: bool = False,
|
||||
document_variable_name: str = "context",
|
||||
@ -44,7 +46,7 @@ def _load_map_rerank_chain(
|
||||
|
||||
|
||||
def _load_stuff_chain(
|
||||
llm: BaseLLM,
|
||||
llm: BaseLanguageModel,
|
||||
prompt: BasePromptTemplate = stuff_prompt.PROMPT,
|
||||
document_prompt: BasePromptTemplate = stuff_prompt.EXAMPLE_PROMPT,
|
||||
document_variable_name: str = "summaries",
|
||||
@ -62,15 +64,15 @@ def _load_stuff_chain(
|
||||
|
||||
|
||||
def _load_map_reduce_chain(
|
||||
llm: BaseLLM,
|
||||
llm: BaseLanguageModel,
|
||||
question_prompt: BasePromptTemplate = map_reduce_prompt.QUESTION_PROMPT,
|
||||
combine_prompt: BasePromptTemplate = map_reduce_prompt.COMBINE_PROMPT,
|
||||
document_prompt: BasePromptTemplate = map_reduce_prompt.EXAMPLE_PROMPT,
|
||||
combine_document_variable_name: str = "summaries",
|
||||
map_reduce_document_variable_name: str = "context",
|
||||
collapse_prompt: Optional[BasePromptTemplate] = None,
|
||||
reduce_llm: Optional[BaseLLM] = None,
|
||||
collapse_llm: Optional[BaseLLM] = None,
|
||||
reduce_llm: Optional[BaseLanguageModel] = None,
|
||||
collapse_llm: Optional[BaseLanguageModel] = None,
|
||||
verbose: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
) -> MapReduceDocumentsChain:
|
||||
@ -112,13 +114,13 @@ def _load_map_reduce_chain(
|
||||
|
||||
|
||||
def _load_refine_chain(
|
||||
llm: BaseLLM,
|
||||
llm: BaseLanguageModel,
|
||||
question_prompt: BasePromptTemplate = refine_prompts.DEFAULT_TEXT_QA_PROMPT,
|
||||
refine_prompt: BasePromptTemplate = refine_prompts.DEFAULT_REFINE_PROMPT,
|
||||
document_prompt: BasePromptTemplate = refine_prompts.EXAMPLE_PROMPT,
|
||||
document_variable_name: str = "context_str",
|
||||
initial_response_name: str = "existing_answer",
|
||||
refine_llm: Optional[BaseLLM] = None,
|
||||
refine_llm: Optional[BaseLanguageModel] = None,
|
||||
verbose: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
) -> RefineDocumentsChain:
|
||||
@ -137,7 +139,7 @@ def _load_refine_chain(
|
||||
|
||||
|
||||
def load_qa_with_sources_chain(
|
||||
llm: BaseLLM,
|
||||
llm: BaseLanguageModel,
|
||||
chain_type: str = "stuff",
|
||||
verbose: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
|
@ -8,8 +8,8 @@ from pydantic import BaseModel, Extra, Field
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.sql_database.prompt import DECIDER_PROMPT, PROMPT
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.schema import BaseLanguageModel
|
||||
from langchain.sql_database import SQLDatabase
|
||||
|
||||
|
||||
@ -24,7 +24,7 @@ class SQLDatabaseChain(Chain, BaseModel):
|
||||
db_chain = SQLDatabaseChain(llm=OpenAI(), database=db)
|
||||
"""
|
||||
|
||||
llm: BaseLLM
|
||||
llm: BaseLanguageModel
|
||||
"""LLM wrapper to use."""
|
||||
database: SQLDatabase = Field(exclude=True)
|
||||
"""SQL Database to connect to."""
|
||||
@ -122,7 +122,7 @@ class SQLDatabaseSequentialChain(Chain, BaseModel):
|
||||
@classmethod
|
||||
def from_llm(
|
||||
cls,
|
||||
llm: BaseLLM,
|
||||
llm: BaseLanguageModel,
|
||||
database: SQLDatabase,
|
||||
query_prompt: BasePromptTemplate = PROMPT,
|
||||
decider_prompt: BasePromptTemplate = DECIDER_PROMPT,
|
||||
|
@ -7,19 +7,21 @@ from langchain.chains.combine_documents.refine import RefineDocumentsChain
|
||||
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.summarize import map_reduce_prompt, refine_prompts, stuff_prompt
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.schema import BaseLanguageModel
|
||||
|
||||
|
||||
class LoadingCallable(Protocol):
|
||||
"""Interface for loading the combine documents chain."""
|
||||
|
||||
def __call__(self, llm: BaseLLM, **kwargs: Any) -> BaseCombineDocumentsChain:
|
||||
def __call__(
|
||||
self, llm: BaseLanguageModel, **kwargs: Any
|
||||
) -> BaseCombineDocumentsChain:
|
||||
"""Callable to load the combine documents chain."""
|
||||
|
||||
|
||||
def _load_stuff_chain(
|
||||
llm: BaseLLM,
|
||||
llm: BaseLanguageModel,
|
||||
prompt: BasePromptTemplate = stuff_prompt.PROMPT,
|
||||
document_variable_name: str = "text",
|
||||
verbose: Optional[bool] = None,
|
||||
@ -36,14 +38,14 @@ def _load_stuff_chain(
|
||||
|
||||
|
||||
def _load_map_reduce_chain(
|
||||
llm: BaseLLM,
|
||||
llm: BaseLanguageModel,
|
||||
map_prompt: BasePromptTemplate = map_reduce_prompt.PROMPT,
|
||||
combine_prompt: BasePromptTemplate = map_reduce_prompt.PROMPT,
|
||||
combine_document_variable_name: str = "text",
|
||||
map_reduce_document_variable_name: str = "text",
|
||||
collapse_prompt: Optional[BasePromptTemplate] = None,
|
||||
reduce_llm: Optional[BaseLLM] = None,
|
||||
collapse_llm: Optional[BaseLLM] = None,
|
||||
reduce_llm: Optional[BaseLanguageModel] = None,
|
||||
collapse_llm: Optional[BaseLanguageModel] = None,
|
||||
verbose: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
) -> MapReduceDocumentsChain:
|
||||
@ -84,12 +86,12 @@ def _load_map_reduce_chain(
|
||||
|
||||
|
||||
def _load_refine_chain(
|
||||
llm: BaseLLM,
|
||||
llm: BaseLanguageModel,
|
||||
question_prompt: BasePromptTemplate = refine_prompts.PROMPT,
|
||||
refine_prompt: BasePromptTemplate = refine_prompts.REFINE_PROMPT,
|
||||
document_variable_name: str = "text",
|
||||
initial_response_name: str = "existing_answer",
|
||||
refine_llm: Optional[BaseLLM] = None,
|
||||
refine_llm: Optional[BaseLanguageModel] = None,
|
||||
verbose: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
) -> RefineDocumentsChain:
|
||||
@ -107,7 +109,7 @@ def _load_refine_chain(
|
||||
|
||||
|
||||
def load_summarize_chain(
|
||||
llm: BaseLLM,
|
||||
llm: BaseLanguageModel,
|
||||
chain_type: str = "stuff",
|
||||
verbose: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
|
@ -11,8 +11,8 @@ from langchain.chains.combine_documents.stuff import StuffDocumentsChain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.question_answering import load_qa_chain
|
||||
from langchain.chains.question_answering.stuff_prompt import PROMPT_SELECTOR
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain.schema import BaseLanguageModel
|
||||
from langchain.vectorstores.base import VectorStore
|
||||
|
||||
|
||||
@ -103,7 +103,10 @@ class VectorDBQA(Chain, BaseModel):
|
||||
|
||||
@classmethod
|
||||
def from_llm(
|
||||
cls, llm: BaseLLM, prompt: Optional[PromptTemplate] = None, **kwargs: Any
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
prompt: Optional[PromptTemplate] = None,
|
||||
**kwargs: Any,
|
||||
) -> VectorDBQA:
|
||||
"""Initialize from LLM."""
|
||||
_prompt = prompt or PROMPT_SELECTOR.get_prompt(llm)
|
||||
@ -122,7 +125,7 @@ class VectorDBQA(Chain, BaseModel):
|
||||
@classmethod
|
||||
def from_chain_type(
|
||||
cls,
|
||||
llm: BaseLLM,
|
||||
llm: BaseLanguageModel,
|
||||
chain_type: str = "stuff",
|
||||
chain_type_kwargs: Optional[dict] = None,
|
||||
**kwargs: Any,
|
||||
|
Loading…
Reference in New Issue
Block a user