diff --git a/libs/langchain/langchain/__init__.py b/libs/langchain/langchain/__init__.py index 721769d0e2..5a3f9f2128 100644 --- a/libs/langchain/langchain/__init__.py +++ b/libs/langchain/langchain/__init__.py @@ -1,57 +1,12 @@ # ruff: noqa: E402 """Main entrypoint into package.""" +import warnings from importlib import metadata -from typing import Optional - -from langchain.agents import MRKLChain, ReActChain, SelfAskWithSearchChain -from langchain.chains import ( - ConversationChain, - LLMBashChain, - LLMChain, - LLMCheckerChain, - LLMMathChain, - QAWithSourcesChain, - VectorDBQA, - VectorDBQAWithSourcesChain, -) -from langchain.docstore import InMemoryDocstore, Wikipedia -from langchain.llms import ( - Anthropic, - Banana, - CerebriumAI, - Cohere, - ForefrontAI, - GooseAI, - HuggingFaceHub, - HuggingFaceTextGenInference, - LlamaCpp, - Modal, - OpenAI, - Petals, - PipelineAI, - SagemakerEndpoint, - StochasticAI, - Writer, -) -from langchain.llms.huggingface_pipeline import HuggingFacePipeline -from langchain.prompts import ( - FewShotPromptTemplate, - Prompt, - PromptTemplate, -) -from langchain.schema.cache import BaseCache -from langchain.schema.prompt_template import BasePromptTemplate -from langchain.utilities.arxiv import ArxivAPIWrapper -from langchain.utilities.golden_query import GoldenQueryAPIWrapper -from langchain.utilities.google_search import GoogleSearchAPIWrapper -from langchain.utilities.google_serper import GoogleSerperAPIWrapper -from langchain.utilities.powerbi import PowerBIDataset -from langchain.utilities.searx_search import SearxSearchWrapper -from langchain.utilities.serpapi import SerpAPIWrapper -from langchain.utilities.sql_database import SQLDatabase -from langchain.utilities.wikipedia import WikipediaAPIWrapper -from langchain.utilities.wolfram_alpha import WolframAlphaAPIWrapper -from langchain.vectorstores import FAISS, ElasticVectorSearch +from typing import TYPE_CHECKING, Any, Optional + +if TYPE_CHECKING: + from langchain.schema import BaseCache + try: __version__ = metadata.version(__package__) @@ -62,10 +17,200 @@ del metadata # optional, avoids polluting the results of dir(__package__) verbose: bool = False debug: bool = False -llm_cache: Optional[BaseCache] = None +llm_cache: Optional["BaseCache"] = None + + +def __getattr__(name: str) -> Any: + warnings.warn( + f"Importing {name} from langchain root module is no longer supported." + ) + if name == "MRKLChain": + from langchain.agents import MRKLChain + + return MRKLChain + elif name == "ReActChain": + from langchain.agents import ReActChain + + return ReActChain + elif name == "SelfAskWithSearchChain": + from langchain.agents import SelfAskWithSearchChain + + return SelfAskWithSearchChain + elif name == "ConversationChain": + from langchain.chains import ConversationChain + + return ConversationChain + elif name == "LLMBashChain": + from langchain.chains import LLMBashChain + + return LLMBashChain + elif name == "LLMChain": + from langchain.chains import LLMChain + + return LLMChain + elif name == "LLMCheckerChain": + from langchain.chains import LLMCheckerChain + + return LLMCheckerChain + elif name == "LLMMathChain": + from langchain.chains import LLMMathChain + + return LLMMathChain + elif name == "QAWithSourcesChain": + from langchain.chains import QAWithSourcesChain + + return QAWithSourcesChain + elif name == "VectorDBQA": + from langchain.chains import VectorDBQA + + return VectorDBQA + elif name == "VectorDBQAWithSourcesChain": + from langchain.chains import VectorDBQAWithSourcesChain + + return VectorDBQAWithSourcesChain + elif name == "InMemoryDocstore": + from langchain.docstore import InMemoryDocstore + + return InMemoryDocstore + elif name == "Wikipedia": + from langchain.docstore import Wikipedia + + return Wikipedia + elif name == "Anthropic": + from langchain.llms import Anthropic + + return Anthropic + elif name == "Banana": + from langchain.llms import Banana + + return Banana + elif name == "CerebriumAI": + from langchain.llms import CerebriumAI + + return CerebriumAI + elif name == "Cohere": + from langchain.llms import Cohere + + return Cohere + elif name == "ForefrontAI": + from langchain.llms import ForefrontAI + + return ForefrontAI + elif name == "GooseAI": + from langchain.llms import GooseAI + + return GooseAI + elif name == "HuggingFaceHub": + from langchain.llms import HuggingFaceHub + + return HuggingFaceHub + elif name == "HuggingFaceTextGenInference": + from langchain.llms import HuggingFaceTextGenInference + + return HuggingFaceTextGenInference + elif name == "LlamaCpp": + from langchain.llms import LlamaCpp + + return LlamaCpp + elif name == "Modal": + from langchain.llms import Modal + + return Modal + elif name == "OpenAI": + from langchain.llms import OpenAI + + return OpenAI + elif name == "Petals": + from langchain.llms import Petals + + return Petals + elif name == "PipelineAI": + from langchain.llms import PipelineAI + + return PipelineAI + elif name == "SagemakerEndpoint": + from langchain.llms import SagemakerEndpoint + + return SagemakerEndpoint + elif name == "StochasticAI": + from langchain.llms import StochasticAI + + return StochasticAI + elif name == "Writer": + from langchain.llms import Writer + + return Writer + elif name == "HuggingFacePipeline": + from langchain.llms.huggingface_pipeline import HuggingFacePipeline + + return HuggingFacePipeline + elif name == "FewShotPromptTemplate": + from langchain.prompts import FewShotPromptTemplate + + return FewShotPromptTemplate + elif name == "Prompt": + from langchain.prompts import Prompt + + return Prompt + elif name == "PromptTemplate": + from langchain.prompts import PromptTemplate + + return PromptTemplate + elif name == "BasePromptTemplate": + from langchain.schema.prompt_template import BasePromptTemplate + + return BasePromptTemplate + elif name == "ArxivAPIWrapper": + from langchain.utilities import ArxivAPIWrapper + + return ArxivAPIWrapper + elif name == "GoldenQueryAPIWrapper": + from langchain.utilities import GoldenQueryAPIWrapper + + return GoldenQueryAPIWrapper + elif name == "GoogleSearchAPIWrapper": + from langchain.utilities import GoogleSearchAPIWrapper + + return GoogleSearchAPIWrapper + elif name == "GoogleSerperAPIWrapper": + from langchain.utilities import GoogleSerperAPIWrapper + + return GoogleSerperAPIWrapper + elif name == "PowerBIDataset": + from langchain.utilities import PowerBIDataset + + return PowerBIDataset + elif name == "SearxSearchWrapper": + from langchain.utilities import SearxSearchWrapper + + return SearxSearchWrapper + elif name == "WikipediaAPIWrapper": + from langchain.utilities import WikipediaAPIWrapper + + return WikipediaAPIWrapper + elif name == "WolframAlphaAPIWrapper": + from langchain.utilities import WolframAlphaAPIWrapper + + return WolframAlphaAPIWrapper + elif name == "SQLDatabase": + from langchain.utilities import SQLDatabase + + return SQLDatabase + elif name == "FAISS": + from langchain.vectorstores import FAISS + + return FAISS + elif name == "ElasticVectorSearch": + from langchain.vectorstores import ElasticVectorSearch + + return ElasticVectorSearch + # For backwards compatibility + elif name == "SerpAPIChain": + from langchain.utilities import SerpAPIWrapper -# For backwards compatibility -SerpAPIChain = SerpAPIWrapper + return SerpAPIWrapper + else: + raise AttributeError(f"Could not find: {name}") __all__ = [ diff --git a/libs/langchain/langchain/output_parsers/fix.py b/libs/langchain/langchain/output_parsers/fix.py index 84b4b762f3..bf44a06cd8 100644 --- a/libs/langchain/langchain/output_parsers/fix.py +++ b/libs/langchain/langchain/output_parsers/fix.py @@ -1,12 +1,14 @@ from __future__ import annotations -from typing import TypeVar +from typing import TYPE_CHECKING, TypeVar -from langchain.chains.llm import LLMChain from langchain.output_parsers.prompts import NAIVE_FIX_PROMPT from langchain.schema import BaseOutputParser, BasePromptTemplate, OutputParserException from langchain.schema.language_model import BaseLanguageModel +if TYPE_CHECKING: + from langchain.chains.llm import LLMChain + T = TypeVar("T") @@ -37,6 +39,8 @@ class OutputFixingParser(BaseOutputParser[T]): Returns: OutputFixingParser """ + from langchain.chains.llm import LLMChain + chain = LLMChain(llm=llm, prompt=prompt) return cls(parser=parser, retry_chain=chain) diff --git a/libs/langchain/langchain/output_parsers/retry.py b/libs/langchain/langchain/output_parsers/retry.py index c9b7337701..633423aca8 100644 --- a/libs/langchain/langchain/output_parsers/retry.py +++ b/libs/langchain/langchain/output_parsers/retry.py @@ -1,8 +1,7 @@ from __future__ import annotations -from typing import TypeVar +from typing import TYPE_CHECKING, TypeVar -from langchain.chains.llm import LLMChain from langchain.prompts.prompt import PromptTemplate from langchain.schema import ( BaseOutputParser, @@ -12,6 +11,9 @@ from langchain.schema import ( ) from langchain.schema.language_model import BaseLanguageModel +if TYPE_CHECKING: + from langchain.chains.llm import LLMChain + NAIVE_COMPLETION_RETRY = """Prompt: {prompt} Completion: @@ -56,6 +58,8 @@ class RetryOutputParser(BaseOutputParser[T]): parser: BaseOutputParser[T], prompt: BasePromptTemplate = NAIVE_RETRY_PROMPT, ) -> RetryOutputParser[T]: + from langchain.chains.llm import LLMChain + chain = LLMChain(llm=llm, prompt=prompt) return cls(parser=parser, retry_chain=chain) @@ -142,6 +146,8 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]): Returns: A RetryWithErrorOutputParser. """ + from langchain.chains.llm import LLMChain + chain = LLMChain(llm=llm, prompt=prompt) return cls(parser=parser, retry_chain=chain) diff --git a/libs/langchain/langchain/tools/vectorstore/tool.py b/libs/langchain/langchain/tools/vectorstore/tool.py index 02a6da1c7a..a0507964e7 100644 --- a/libs/langchain/langchain/tools/vectorstore/tool.py +++ b/libs/langchain/langchain/tools/vectorstore/tool.py @@ -4,7 +4,6 @@ import json from typing import Any, Dict, Optional from langchain.callbacks.manager import CallbackManagerForToolRun -from langchain.chains import RetrievalQA, RetrievalQAWithSourcesChain from langchain.llms.openai import OpenAI from langchain.pydantic_v1 import BaseModel, Field from langchain.schema.language_model import BaseLanguageModel @@ -48,6 +47,8 @@ class VectorStoreQATool(BaseVectorStoreTool, BaseTool): run_manager: Optional[CallbackManagerForToolRun] = None, ) -> str: """Use the tool.""" + from langchain.chains.retrieval_qa.base import RetrievalQA + chain = RetrievalQA.from_chain_type( self.llm, retriever=self.vectorstore.as_retriever() ) @@ -78,6 +79,11 @@ class VectorStoreQAWithSourcesTool(BaseVectorStoreTool, BaseTool): run_manager: Optional[CallbackManagerForToolRun] = None, ) -> str: """Use the tool.""" + + from langchain.chains.qa_with_sources.retrieval import ( + RetrievalQAWithSourcesChain, + ) + chain = RetrievalQAWithSourcesChain.from_chain_type( self.llm, retriever=self.vectorstore.as_retriever() )