conditional imports (#11017)

This commit is contained in:
Harrison Chase 2023-09-25 15:46:32 -07:00 committed by GitHub
parent 0625ab7a9e
commit c87e9fb2ce
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 219 additions and 58 deletions

View File

@ -1,57 +1,12 @@
# ruff: noqa: E402 # ruff: noqa: E402
"""Main entrypoint into package.""" """Main entrypoint into package."""
import warnings
from importlib import metadata from importlib import metadata
from typing import Optional from typing import TYPE_CHECKING, Any, Optional
if TYPE_CHECKING:
from langchain.schema import BaseCache
from langchain.agents import MRKLChain, ReActChain, SelfAskWithSearchChain
from langchain.chains import (
ConversationChain,
LLMBashChain,
LLMChain,
LLMCheckerChain,
LLMMathChain,
QAWithSourcesChain,
VectorDBQA,
VectorDBQAWithSourcesChain,
)
from langchain.docstore import InMemoryDocstore, Wikipedia
from langchain.llms import (
Anthropic,
Banana,
CerebriumAI,
Cohere,
ForefrontAI,
GooseAI,
HuggingFaceHub,
HuggingFaceTextGenInference,
LlamaCpp,
Modal,
OpenAI,
Petals,
PipelineAI,
SagemakerEndpoint,
StochasticAI,
Writer,
)
from langchain.llms.huggingface_pipeline import HuggingFacePipeline
from langchain.prompts import (
FewShotPromptTemplate,
Prompt,
PromptTemplate,
)
from langchain.schema.cache import BaseCache
from langchain.schema.prompt_template import BasePromptTemplate
from langchain.utilities.arxiv import ArxivAPIWrapper
from langchain.utilities.golden_query import GoldenQueryAPIWrapper
from langchain.utilities.google_search import GoogleSearchAPIWrapper
from langchain.utilities.google_serper import GoogleSerperAPIWrapper
from langchain.utilities.powerbi import PowerBIDataset
from langchain.utilities.searx_search import SearxSearchWrapper
from langchain.utilities.serpapi import SerpAPIWrapper
from langchain.utilities.sql_database import SQLDatabase
from langchain.utilities.wikipedia import WikipediaAPIWrapper
from langchain.utilities.wolfram_alpha import WolframAlphaAPIWrapper
from langchain.vectorstores import FAISS, ElasticVectorSearch
try: try:
__version__ = metadata.version(__package__) __version__ = metadata.version(__package__)
@ -62,10 +17,200 @@ del metadata # optional, avoids polluting the results of dir(__package__)
verbose: bool = False verbose: bool = False
debug: bool = False debug: bool = False
llm_cache: Optional[BaseCache] = None llm_cache: Optional["BaseCache"] = None
# For backwards compatibility
SerpAPIChain = SerpAPIWrapper def __getattr__(name: str) -> Any:
warnings.warn(
f"Importing {name} from langchain root module is no longer supported."
)
if name == "MRKLChain":
from langchain.agents import MRKLChain
return MRKLChain
elif name == "ReActChain":
from langchain.agents import ReActChain
return ReActChain
elif name == "SelfAskWithSearchChain":
from langchain.agents import SelfAskWithSearchChain
return SelfAskWithSearchChain
elif name == "ConversationChain":
from langchain.chains import ConversationChain
return ConversationChain
elif name == "LLMBashChain":
from langchain.chains import LLMBashChain
return LLMBashChain
elif name == "LLMChain":
from langchain.chains import LLMChain
return LLMChain
elif name == "LLMCheckerChain":
from langchain.chains import LLMCheckerChain
return LLMCheckerChain
elif name == "LLMMathChain":
from langchain.chains import LLMMathChain
return LLMMathChain
elif name == "QAWithSourcesChain":
from langchain.chains import QAWithSourcesChain
return QAWithSourcesChain
elif name == "VectorDBQA":
from langchain.chains import VectorDBQA
return VectorDBQA
elif name == "VectorDBQAWithSourcesChain":
from langchain.chains import VectorDBQAWithSourcesChain
return VectorDBQAWithSourcesChain
elif name == "InMemoryDocstore":
from langchain.docstore import InMemoryDocstore
return InMemoryDocstore
elif name == "Wikipedia":
from langchain.docstore import Wikipedia
return Wikipedia
elif name == "Anthropic":
from langchain.llms import Anthropic
return Anthropic
elif name == "Banana":
from langchain.llms import Banana
return Banana
elif name == "CerebriumAI":
from langchain.llms import CerebriumAI
return CerebriumAI
elif name == "Cohere":
from langchain.llms import Cohere
return Cohere
elif name == "ForefrontAI":
from langchain.llms import ForefrontAI
return ForefrontAI
elif name == "GooseAI":
from langchain.llms import GooseAI
return GooseAI
elif name == "HuggingFaceHub":
from langchain.llms import HuggingFaceHub
return HuggingFaceHub
elif name == "HuggingFaceTextGenInference":
from langchain.llms import HuggingFaceTextGenInference
return HuggingFaceTextGenInference
elif name == "LlamaCpp":
from langchain.llms import LlamaCpp
return LlamaCpp
elif name == "Modal":
from langchain.llms import Modal
return Modal
elif name == "OpenAI":
from langchain.llms import OpenAI
return OpenAI
elif name == "Petals":
from langchain.llms import Petals
return Petals
elif name == "PipelineAI":
from langchain.llms import PipelineAI
return PipelineAI
elif name == "SagemakerEndpoint":
from langchain.llms import SagemakerEndpoint
return SagemakerEndpoint
elif name == "StochasticAI":
from langchain.llms import StochasticAI
return StochasticAI
elif name == "Writer":
from langchain.llms import Writer
return Writer
elif name == "HuggingFacePipeline":
from langchain.llms.huggingface_pipeline import HuggingFacePipeline
return HuggingFacePipeline
elif name == "FewShotPromptTemplate":
from langchain.prompts import FewShotPromptTemplate
return FewShotPromptTemplate
elif name == "Prompt":
from langchain.prompts import Prompt
return Prompt
elif name == "PromptTemplate":
from langchain.prompts import PromptTemplate
return PromptTemplate
elif name == "BasePromptTemplate":
from langchain.schema.prompt_template import BasePromptTemplate
return BasePromptTemplate
elif name == "ArxivAPIWrapper":
from langchain.utilities import ArxivAPIWrapper
return ArxivAPIWrapper
elif name == "GoldenQueryAPIWrapper":
from langchain.utilities import GoldenQueryAPIWrapper
return GoldenQueryAPIWrapper
elif name == "GoogleSearchAPIWrapper":
from langchain.utilities import GoogleSearchAPIWrapper
return GoogleSearchAPIWrapper
elif name == "GoogleSerperAPIWrapper":
from langchain.utilities import GoogleSerperAPIWrapper
return GoogleSerperAPIWrapper
elif name == "PowerBIDataset":
from langchain.utilities import PowerBIDataset
return PowerBIDataset
elif name == "SearxSearchWrapper":
from langchain.utilities import SearxSearchWrapper
return SearxSearchWrapper
elif name == "WikipediaAPIWrapper":
from langchain.utilities import WikipediaAPIWrapper
return WikipediaAPIWrapper
elif name == "WolframAlphaAPIWrapper":
from langchain.utilities import WolframAlphaAPIWrapper
return WolframAlphaAPIWrapper
elif name == "SQLDatabase":
from langchain.utilities import SQLDatabase
return SQLDatabase
elif name == "FAISS":
from langchain.vectorstores import FAISS
return FAISS
elif name == "ElasticVectorSearch":
from langchain.vectorstores import ElasticVectorSearch
return ElasticVectorSearch
# For backwards compatibility
elif name == "SerpAPIChain":
from langchain.utilities import SerpAPIWrapper
return SerpAPIWrapper
else:
raise AttributeError(f"Could not find: {name}")
__all__ = [ __all__ = [

View File

@ -1,12 +1,14 @@
from __future__ import annotations from __future__ import annotations
from typing import TypeVar from typing import TYPE_CHECKING, TypeVar
from langchain.chains.llm import LLMChain
from langchain.output_parsers.prompts import NAIVE_FIX_PROMPT from langchain.output_parsers.prompts import NAIVE_FIX_PROMPT
from langchain.schema import BaseOutputParser, BasePromptTemplate, OutputParserException from langchain.schema import BaseOutputParser, BasePromptTemplate, OutputParserException
from langchain.schema.language_model import BaseLanguageModel from langchain.schema.language_model import BaseLanguageModel
if TYPE_CHECKING:
from langchain.chains.llm import LLMChain
T = TypeVar("T") T = TypeVar("T")
@ -37,6 +39,8 @@ class OutputFixingParser(BaseOutputParser[T]):
Returns: Returns:
OutputFixingParser OutputFixingParser
""" """
from langchain.chains.llm import LLMChain
chain = LLMChain(llm=llm, prompt=prompt) chain = LLMChain(llm=llm, prompt=prompt)
return cls(parser=parser, retry_chain=chain) return cls(parser=parser, retry_chain=chain)

View File

@ -1,8 +1,7 @@
from __future__ import annotations from __future__ import annotations
from typing import TypeVar from typing import TYPE_CHECKING, TypeVar
from langchain.chains.llm import LLMChain
from langchain.prompts.prompt import PromptTemplate from langchain.prompts.prompt import PromptTemplate
from langchain.schema import ( from langchain.schema import (
BaseOutputParser, BaseOutputParser,
@ -12,6 +11,9 @@ from langchain.schema import (
) )
from langchain.schema.language_model import BaseLanguageModel from langchain.schema.language_model import BaseLanguageModel
if TYPE_CHECKING:
from langchain.chains.llm import LLMChain
NAIVE_COMPLETION_RETRY = """Prompt: NAIVE_COMPLETION_RETRY = """Prompt:
{prompt} {prompt}
Completion: Completion:
@ -56,6 +58,8 @@ class RetryOutputParser(BaseOutputParser[T]):
parser: BaseOutputParser[T], parser: BaseOutputParser[T],
prompt: BasePromptTemplate = NAIVE_RETRY_PROMPT, prompt: BasePromptTemplate = NAIVE_RETRY_PROMPT,
) -> RetryOutputParser[T]: ) -> RetryOutputParser[T]:
from langchain.chains.llm import LLMChain
chain = LLMChain(llm=llm, prompt=prompt) chain = LLMChain(llm=llm, prompt=prompt)
return cls(parser=parser, retry_chain=chain) return cls(parser=parser, retry_chain=chain)
@ -142,6 +146,8 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]):
Returns: Returns:
A RetryWithErrorOutputParser. A RetryWithErrorOutputParser.
""" """
from langchain.chains.llm import LLMChain
chain = LLMChain(llm=llm, prompt=prompt) chain = LLMChain(llm=llm, prompt=prompt)
return cls(parser=parser, retry_chain=chain) return cls(parser=parser, retry_chain=chain)

View File

@ -4,7 +4,6 @@ import json
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from langchain.callbacks.manager import CallbackManagerForToolRun from langchain.callbacks.manager import CallbackManagerForToolRun
from langchain.chains import RetrievalQA, RetrievalQAWithSourcesChain
from langchain.llms.openai import OpenAI from langchain.llms.openai import OpenAI
from langchain.pydantic_v1 import BaseModel, Field from langchain.pydantic_v1 import BaseModel, Field
from langchain.schema.language_model import BaseLanguageModel from langchain.schema.language_model import BaseLanguageModel
@ -48,6 +47,8 @@ class VectorStoreQATool(BaseVectorStoreTool, BaseTool):
run_manager: Optional[CallbackManagerForToolRun] = None, run_manager: Optional[CallbackManagerForToolRun] = None,
) -> str: ) -> str:
"""Use the tool.""" """Use the tool."""
from langchain.chains.retrieval_qa.base import RetrievalQA
chain = RetrievalQA.from_chain_type( chain = RetrievalQA.from_chain_type(
self.llm, retriever=self.vectorstore.as_retriever() self.llm, retriever=self.vectorstore.as_retriever()
) )
@ -78,6 +79,11 @@ class VectorStoreQAWithSourcesTool(BaseVectorStoreTool, BaseTool):
run_manager: Optional[CallbackManagerForToolRun] = None, run_manager: Optional[CallbackManagerForToolRun] = None,
) -> str: ) -> str:
"""Use the tool.""" """Use the tool."""
from langchain.chains.qa_with_sources.retrieval import (
RetrievalQAWithSourcesChain,
)
chain = RetrievalQAWithSourcesChain.from_chain_type( chain = RetrievalQAWithSourcesChain.from_chain_type(
self.llm, retriever=self.vectorstore.as_retriever() self.llm, retriever=self.vectorstore.as_retriever()
) )