Replace remaining usage of basellm with baselangmodel (#3981)

This commit is contained in:
Nuno Campos 2023-05-03 05:52:29 +01:00 committed by GitHub
parent f291fd7eed
commit f3ec6d2449
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
29 changed files with 97 additions and 86 deletions

View File

@ -3,11 +3,14 @@ from typing import Any, Optional
from langchain.agents.agent import AgentExecutor from langchain.agents.agent import AgentExecutor
from langchain.agents.agent_toolkits.pandas.base import create_pandas_dataframe_agent from langchain.agents.agent_toolkits.pandas.base import create_pandas_dataframe_agent
from langchain.llms.base import BaseLLM from langchain.base_language import BaseLanguageModel
def create_csv_agent( def create_csv_agent(
llm: BaseLLM, path: str, pandas_kwargs: Optional[dict] = None, **kwargs: Any llm: BaseLanguageModel,
path: str,
pandas_kwargs: Optional[dict] = None,
**kwargs: Any
) -> AgentExecutor: ) -> AgentExecutor:
"""Create csv agent by loading to a dataframe and using pandas agent.""" """Create csv agent by loading to a dataframe and using pandas agent."""
import pandas as pd import pandas as pd

View File

@ -6,13 +6,13 @@ from langchain.agents.agent_toolkits.json.prompt import JSON_PREFIX, JSON_SUFFIX
from langchain.agents.agent_toolkits.json.toolkit import JsonToolkit from langchain.agents.agent_toolkits.json.toolkit import JsonToolkit
from langchain.agents.mrkl.base import ZeroShotAgent from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.llms.base import BaseLLM
def create_json_agent( def create_json_agent(
llm: BaseLLM, llm: BaseLanguageModel,
toolkit: JsonToolkit, toolkit: JsonToolkit,
callback_manager: Optional[BaseCallbackManager] = None, callback_manager: Optional[BaseCallbackManager] = None,
prefix: str = JSON_PREFIX, prefix: str = JSON_PREFIX,

View File

@ -4,8 +4,8 @@
from typing import Any, Optional from typing import Any, Optional
from langchain.agents.tools import Tool from langchain.agents.tools import Tool
from langchain.base_language import BaseLanguageModel
from langchain.chains.api.openapi.chain import OpenAPIEndpointChain from langchain.chains.api.openapi.chain import OpenAPIEndpointChain
from langchain.llms.base import BaseLLM
from langchain.requests import Requests from langchain.requests import Requests
from langchain.tools.openapi.utils.api_models import APIOperation from langchain.tools.openapi.utils.api_models import APIOperation
from langchain.tools.openapi.utils.openapi_utils import OpenAPISpec from langchain.tools.openapi.utils.openapi_utils import OpenAPISpec
@ -32,7 +32,7 @@ class NLATool(Tool):
@classmethod @classmethod
def from_llm_and_method( def from_llm_and_method(
cls, cls,
llm: BaseLLM, llm: BaseLanguageModel,
path: str, path: str,
method: str, method: str,
spec: OpenAPISpec, spec: OpenAPISpec,

View File

@ -7,7 +7,7 @@ from pydantic import Field
from langchain.agents.agent_toolkits.base import BaseToolkit from langchain.agents.agent_toolkits.base import BaseToolkit
from langchain.agents.agent_toolkits.nla.tool import NLATool from langchain.agents.agent_toolkits.nla.tool import NLATool
from langchain.llms.base import BaseLLM from langchain.base_language import BaseLanguageModel
from langchain.requests import Requests from langchain.requests import Requests
from langchain.tools.base import BaseTool from langchain.tools.base import BaseTool
from langchain.tools.openapi.utils.openapi_utils import OpenAPISpec from langchain.tools.openapi.utils.openapi_utils import OpenAPISpec
@ -26,7 +26,7 @@ class NLAToolkit(BaseToolkit):
@staticmethod @staticmethod
def _get_http_operation_tools( def _get_http_operation_tools(
llm: BaseLLM, llm: BaseLanguageModel,
spec: OpenAPISpec, spec: OpenAPISpec,
requests: Optional[Requests] = None, requests: Optional[Requests] = None,
verbose: bool = False, verbose: bool = False,
@ -53,7 +53,7 @@ class NLAToolkit(BaseToolkit):
@classmethod @classmethod
def from_llm_and_spec( def from_llm_and_spec(
cls, cls,
llm: BaseLLM, llm: BaseLanguageModel,
spec: OpenAPISpec, spec: OpenAPISpec,
requests: Optional[Requests] = None, requests: Optional[Requests] = None,
verbose: bool = False, verbose: bool = False,
@ -68,7 +68,7 @@ class NLAToolkit(BaseToolkit):
@classmethod @classmethod
def from_llm_and_url( def from_llm_and_url(
cls, cls,
llm: BaseLLM, llm: BaseLanguageModel,
open_api_url: str, open_api_url: str,
requests: Optional[Requests] = None, requests: Optional[Requests] = None,
verbose: bool = False, verbose: bool = False,
@ -83,7 +83,7 @@ class NLAToolkit(BaseToolkit):
@classmethod @classmethod
def from_llm_and_ai_plugin( def from_llm_and_ai_plugin(
cls, cls,
llm: BaseLLM, llm: BaseLanguageModel,
ai_plugin: AIPlugin, ai_plugin: AIPlugin,
requests: Optional[Requests] = None, requests: Optional[Requests] = None,
verbose: bool = False, verbose: bool = False,
@ -103,7 +103,7 @@ class NLAToolkit(BaseToolkit):
@classmethod @classmethod
def from_llm_and_ai_plugin_url( def from_llm_and_ai_plugin_url(
cls, cls,
llm: BaseLLM, llm: BaseLanguageModel,
ai_plugin_url: str, ai_plugin_url: str,
requests: Optional[Requests] = None, requests: Optional[Requests] = None,
verbose: bool = False, verbose: bool = False,

View File

@ -9,13 +9,13 @@ from langchain.agents.agent_toolkits.openapi.prompt import (
from langchain.agents.agent_toolkits.openapi.toolkit import OpenAPIToolkit from langchain.agents.agent_toolkits.openapi.toolkit import OpenAPIToolkit
from langchain.agents.mrkl.base import ZeroShotAgent from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.llms.base import BaseLLM
def create_openapi_agent( def create_openapi_agent(
llm: BaseLLM, llm: BaseLanguageModel,
toolkit: OpenAPIToolkit, toolkit: OpenAPIToolkit,
callback_manager: Optional[BaseCallbackManager] = None, callback_manager: Optional[BaseCallbackManager] = None,
prefix: str = OPENAPI_PREFIX, prefix: str = OPENAPI_PREFIX,

View File

@ -9,7 +9,7 @@ from langchain.agents.agent_toolkits.json.base import create_json_agent
from langchain.agents.agent_toolkits.json.toolkit import JsonToolkit from langchain.agents.agent_toolkits.json.toolkit import JsonToolkit
from langchain.agents.agent_toolkits.openapi.prompt import DESCRIPTION from langchain.agents.agent_toolkits.openapi.prompt import DESCRIPTION
from langchain.agents.tools import Tool from langchain.agents.tools import Tool
from langchain.llms.base import BaseLLM from langchain.base_language import BaseLanguageModel
from langchain.requests import TextRequestsWrapper from langchain.requests import TextRequestsWrapper
from langchain.tools import BaseTool from langchain.tools import BaseTool
from langchain.tools.json.tool import JsonSpec from langchain.tools.json.tool import JsonSpec
@ -57,7 +57,7 @@ class OpenAPIToolkit(BaseToolkit):
@classmethod @classmethod
def from_llm( def from_llm(
cls, cls,
llm: BaseLLM, llm: BaseLanguageModel,
json_spec: JsonSpec, json_spec: JsonSpec,
requests_wrapper: TextRequestsWrapper, requests_wrapper: TextRequestsWrapper,
**kwargs: Any, **kwargs: Any,

View File

@ -4,14 +4,14 @@ from typing import Any, Dict, List, Optional
from langchain.agents.agent import AgentExecutor from langchain.agents.agent import AgentExecutor
from langchain.agents.agent_toolkits.pandas.prompt import PREFIX, SUFFIX from langchain.agents.agent_toolkits.pandas.prompt import PREFIX, SUFFIX
from langchain.agents.mrkl.base import ZeroShotAgent from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.llms.base import BaseLLM
from langchain.tools.python.tool import PythonAstREPLTool from langchain.tools.python.tool import PythonAstREPLTool
def create_pandas_dataframe_agent( def create_pandas_dataframe_agent(
llm: BaseLLM, llm: BaseLanguageModel,
df: Any, df: Any,
callback_manager: Optional[BaseCallbackManager] = None, callback_manager: Optional[BaseCallbackManager] = None,
prefix: str = PREFIX, prefix: str = PREFIX,

View File

@ -9,14 +9,14 @@ from langchain.agents.agent_toolkits.powerbi.prompt import (
from langchain.agents.agent_toolkits.powerbi.toolkit import PowerBIToolkit from langchain.agents.agent_toolkits.powerbi.toolkit import PowerBIToolkit
from langchain.agents.mrkl.base import ZeroShotAgent from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.llms.base import BaseLLM
from langchain.utilities.powerbi import PowerBIDataset from langchain.utilities.powerbi import PowerBIDataset
def create_pbi_agent( def create_pbi_agent(
llm: BaseLLM, llm: BaseLanguageModel,
toolkit: Optional[PowerBIToolkit], toolkit: Optional[PowerBIToolkit],
powerbi: Optional[PowerBIDataset] = None, powerbi: Optional[PowerBIDataset] = None,
callback_manager: Optional[BaseCallbackManager] = None, callback_manager: Optional[BaseCallbackManager] = None,

View File

@ -5,14 +5,14 @@ from typing import Any, Dict, Optional
from langchain.agents.agent import AgentExecutor from langchain.agents.agent import AgentExecutor
from langchain.agents.agent_toolkits.python.prompt import PREFIX from langchain.agents.agent_toolkits.python.prompt import PREFIX
from langchain.agents.mrkl.base import ZeroShotAgent from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.llms.base import BaseLLM
from langchain.tools.python.tool import PythonREPLTool from langchain.tools.python.tool import PythonREPLTool
def create_python_agent( def create_python_agent(
llm: BaseLLM, llm: BaseLanguageModel,
tool: PythonREPLTool, tool: PythonREPLTool,
callback_manager: Optional[BaseCallbackManager] = None, callback_manager: Optional[BaseCallbackManager] = None,
verbose: bool = False, verbose: bool = False,

View File

@ -6,13 +6,13 @@ from langchain.agents.agent_toolkits.sql.prompt import SQL_PREFIX, SQL_SUFFIX
from langchain.agents.agent_toolkits.sql.toolkit import SQLDatabaseToolkit from langchain.agents.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain.agents.mrkl.base import ZeroShotAgent from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.llms.base import BaseLLM
def create_sql_agent( def create_sql_agent(
llm: BaseLLM, llm: BaseLanguageModel,
toolkit: SQLDatabaseToolkit, toolkit: SQLDatabaseToolkit,
callback_manager: Optional[BaseCallbackManager] = None, callback_manager: Optional[BaseCallbackManager] = None,
prefix: str = SQL_PREFIX, prefix: str = SQL_PREFIX,

View File

@ -8,13 +8,13 @@ from langchain.agents.agent_toolkits.vectorstore.toolkit import (
VectorStoreToolkit, VectorStoreToolkit,
) )
from langchain.agents.mrkl.base import ZeroShotAgent from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.llms.base import BaseLLM
def create_vectorstore_agent( def create_vectorstore_agent(
llm: BaseLLM, llm: BaseLanguageModel,
toolkit: VectorStoreToolkit, toolkit: VectorStoreToolkit,
callback_manager: Optional[BaseCallbackManager] = None, callback_manager: Optional[BaseCallbackManager] = None,
prefix: str = PREFIX, prefix: str = PREFIX,
@ -42,7 +42,7 @@ def create_vectorstore_agent(
def create_vectorstore_router_agent( def create_vectorstore_router_agent(
llm: BaseLLM, llm: BaseLanguageModel,
toolkit: VectorStoreRouterToolkit, toolkit: VectorStoreRouterToolkit,
callback_manager: Optional[BaseCallbackManager] = None, callback_manager: Optional[BaseCallbackManager] = None,
prefix: str = ROUTER_PREFIX, prefix: str = ROUTER_PREFIX,

View File

@ -5,12 +5,12 @@ from typing import Any, Dict, List, Optional, Callable, Tuple
from mypy_extensions import Arg, KwArg from mypy_extensions import Arg, KwArg
from langchain.agents.tools import Tool from langchain.agents.tools import Tool
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.api import news_docs, open_meteo_docs, podcast_docs, tmdb_docs from langchain.chains.api import news_docs, open_meteo_docs, podcast_docs, tmdb_docs
from langchain.chains.api.base import APIChain from langchain.chains.api.base import APIChain
from langchain.chains.llm_math.base import LLMMathChain from langchain.chains.llm_math.base import LLMMathChain
from langchain.chains.pal.base import PALChain from langchain.chains.pal.base import PALChain
from langchain.llms.base import BaseLLM
from langchain.requests import TextRequestsWrapper from langchain.requests import TextRequestsWrapper
from langchain.tools.arxiv.tool import ArxivQueryRun from langchain.tools.arxiv.tool import ArxivQueryRun
from langchain.tools.base import BaseTool from langchain.tools.base import BaseTool
@ -32,8 +32,6 @@ from langchain.tools.shell.tool import ShellTool
from langchain.tools.wikipedia.tool import WikipediaQueryRun from langchain.tools.wikipedia.tool import WikipediaQueryRun
from langchain.tools.wolfram_alpha.tool import WolframAlphaQueryRun from langchain.tools.wolfram_alpha.tool import WolframAlphaQueryRun
from langchain.utilities import ArxivAPIWrapper from langchain.utilities import ArxivAPIWrapper
from langchain.utilities.apify import ApifyWrapper
from langchain.utilities.bash import BashProcess
from langchain.utilities.bing_search import BingSearchAPIWrapper from langchain.utilities.bing_search import BingSearchAPIWrapper
from langchain.utilities.duckduckgo_search import DuckDuckGoSearchAPIWrapper from langchain.utilities.duckduckgo_search import DuckDuckGoSearchAPIWrapper
from langchain.utilities.google_search import GoogleSearchAPIWrapper from langchain.utilities.google_search import GoogleSearchAPIWrapper
@ -85,7 +83,7 @@ _BASE_TOOLS: Dict[str, Callable[[], BaseTool]] = {
} }
def _get_pal_math(llm: BaseLLM) -> BaseTool: def _get_pal_math(llm: BaseLanguageModel) -> BaseTool:
return Tool( return Tool(
name="PAL-MATH", name="PAL-MATH",
description="A language model that is really good at solving complex word math problems. Input should be a fully worded hard word math problem.", description="A language model that is really good at solving complex word math problems. Input should be a fully worded hard word math problem.",
@ -93,7 +91,7 @@ def _get_pal_math(llm: BaseLLM) -> BaseTool:
) )
def _get_pal_colored_objects(llm: BaseLLM) -> BaseTool: def _get_pal_colored_objects(llm: BaseLanguageModel) -> BaseTool:
return Tool( return Tool(
name="PAL-COLOR-OBJ", name="PAL-COLOR-OBJ",
description="A language model that is really good at reasoning about position and the color attributes of objects. Input should be a fully worded hard reasoning problem. Make sure to include all information about the objects AND the final question you want to answer.", description="A language model that is really good at reasoning about position and the color attributes of objects. Input should be a fully worded hard reasoning problem. Make sure to include all information about the objects AND the final question you want to answer.",
@ -101,7 +99,7 @@ def _get_pal_colored_objects(llm: BaseLLM) -> BaseTool:
) )
def _get_llm_math(llm: BaseLLM) -> BaseTool: def _get_llm_math(llm: BaseLanguageModel) -> BaseTool:
return Tool( return Tool(
name="Calculator", name="Calculator",
description="Useful for when you need to answer questions about math.", description="Useful for when you need to answer questions about math.",
@ -110,7 +108,7 @@ def _get_llm_math(llm: BaseLLM) -> BaseTool:
) )
def _get_open_meteo_api(llm: BaseLLM) -> BaseTool: def _get_open_meteo_api(llm: BaseLanguageModel) -> BaseTool:
chain = APIChain.from_llm_and_api_docs(llm, open_meteo_docs.OPEN_METEO_DOCS) chain = APIChain.from_llm_and_api_docs(llm, open_meteo_docs.OPEN_METEO_DOCS)
return Tool( return Tool(
name="Open Meteo API", name="Open Meteo API",
@ -119,7 +117,7 @@ def _get_open_meteo_api(llm: BaseLLM) -> BaseTool:
) )
_LLM_TOOLS: Dict[str, Callable[[BaseLLM], BaseTool]] = { _LLM_TOOLS: Dict[str, Callable[[BaseLanguageModel], BaseTool]] = {
"pal-math": _get_pal_math, "pal-math": _get_pal_math,
"pal-colored-objects": _get_pal_colored_objects, "pal-colored-objects": _get_pal_colored_objects,
"llm-math": _get_llm_math, "llm-math": _get_llm_math,
@ -127,7 +125,7 @@ _LLM_TOOLS: Dict[str, Callable[[BaseLLM], BaseTool]] = {
} }
def _get_news_api(llm: BaseLLM, **kwargs: Any) -> BaseTool: def _get_news_api(llm: BaseLanguageModel, **kwargs: Any) -> BaseTool:
news_api_key = kwargs["news_api_key"] news_api_key = kwargs["news_api_key"]
chain = APIChain.from_llm_and_api_docs( chain = APIChain.from_llm_and_api_docs(
llm, news_docs.NEWS_DOCS, headers={"X-Api-Key": news_api_key} llm, news_docs.NEWS_DOCS, headers={"X-Api-Key": news_api_key}
@ -139,7 +137,7 @@ def _get_news_api(llm: BaseLLM, **kwargs: Any) -> BaseTool:
) )
def _get_tmdb_api(llm: BaseLLM, **kwargs: Any) -> BaseTool: def _get_tmdb_api(llm: BaseLanguageModel, **kwargs: Any) -> BaseTool:
tmdb_bearer_token = kwargs["tmdb_bearer_token"] tmdb_bearer_token = kwargs["tmdb_bearer_token"]
chain = APIChain.from_llm_and_api_docs( chain = APIChain.from_llm_and_api_docs(
llm, llm,
@ -153,7 +151,7 @@ def _get_tmdb_api(llm: BaseLLM, **kwargs: Any) -> BaseTool:
) )
def _get_podcast_api(llm: BaseLLM, **kwargs: Any) -> BaseTool: def _get_podcast_api(llm: BaseLanguageModel, **kwargs: Any) -> BaseTool:
listen_api_key = kwargs["listen_api_key"] listen_api_key = kwargs["listen_api_key"]
chain = APIChain.from_llm_and_api_docs( chain = APIChain.from_llm_and_api_docs(
llm, llm,
@ -238,7 +236,8 @@ def _get_scenexplain(**kwargs: Any) -> BaseTool:
_EXTRA_LLM_TOOLS: Dict[ _EXTRA_LLM_TOOLS: Dict[
str, Tuple[Callable[[Arg(BaseLLM, "llm"), KwArg(Any)], BaseTool], List[str]] str,
Tuple[Callable[[Arg(BaseLanguageModel, "llm"), KwArg(Any)], BaseTool], List[str]],
] = { ] = {
"news-api": (_get_news_api, ["news_api_key"]), "news-api": (_get_news_api, ["news_api_key"]),
"tmdb-api": (_get_tmdb_api, ["tmdb_bearer_token"]), "tmdb-api": (_get_tmdb_api, ["tmdb_bearer_token"]),
@ -277,7 +276,7 @@ _EXTRA_OPTIONAL_TOOLS: Dict[str, Tuple[Callable[[KwArg(Any)], BaseTool], List[st
def load_tools( def load_tools(
tool_names: List[str], tool_names: List[str],
llm: Optional[BaseLLM] = None, llm: Optional[BaseLanguageModel] = None,
callback_manager: Optional[BaseCallbackManager] = None, callback_manager: Optional[BaseCallbackManager] = None,
**kwargs: Any, **kwargs: Any,
) -> List[BaseTool]: ) -> List[BaseTool]:

View File

@ -8,15 +8,15 @@ import yaml
from langchain.agents.agent import BaseSingleActionAgent from langchain.agents.agent import BaseSingleActionAgent
from langchain.agents.tools import Tool from langchain.agents.tools import Tool
from langchain.agents.types import AGENT_TO_CLASS from langchain.agents.types import AGENT_TO_CLASS
from langchain.base_language import BaseLanguageModel
from langchain.chains.loading import load_chain, load_chain_from_config from langchain.chains.loading import load_chain, load_chain_from_config
from langchain.llms.base import BaseLLM
from langchain.utilities.loading import try_load_from_hub from langchain.utilities.loading import try_load_from_hub
URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/agents/" URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/agents/"
def _load_agent_from_tools( def _load_agent_from_tools(
config: dict, llm: BaseLLM, tools: List[Tool], **kwargs: Any config: dict, llm: BaseLanguageModel, tools: List[Tool], **kwargs: Any
) -> BaseSingleActionAgent: ) -> BaseSingleActionAgent:
config_type = config.pop("_type") config_type = config.pop("_type")
if config_type not in AGENT_TO_CLASS: if config_type not in AGENT_TO_CLASS:
@ -29,7 +29,7 @@ def _load_agent_from_tools(
def load_agent_from_config( def load_agent_from_config(
config: dict, config: dict,
llm: Optional[BaseLLM] = None, llm: Optional[BaseLanguageModel] = None,
tools: Optional[List[Tool]] = None, tools: Optional[List[Tool]] = None,
**kwargs: Any, **kwargs: Any,
) -> BaseSingleActionAgent: ) -> BaseSingleActionAgent:

View File

@ -10,9 +10,9 @@ from langchain.agents.react.textworld_prompt import TEXTWORLD_PROMPT
from langchain.agents.react.wiki_prompt import WIKI_PROMPT from langchain.agents.react.wiki_prompt import WIKI_PROMPT
from langchain.agents.tools import Tool from langchain.agents.tools import Tool
from langchain.agents.utils import validate_tools_single_input from langchain.agents.utils import validate_tools_single_input
from langchain.base_language import BaseLanguageModel
from langchain.docstore.base import Docstore from langchain.docstore.base import Docstore
from langchain.docstore.document import Document from langchain.docstore.document import Document
from langchain.llms.base import BaseLLM
from langchain.prompts.base import BasePromptTemplate from langchain.prompts.base import BasePromptTemplate
from langchain.tools.base import BaseTool from langchain.tools.base import BaseTool
@ -141,7 +141,7 @@ class ReActChain(AgentExecutor):
react = ReAct(llm=OpenAI()) react = ReAct(llm=OpenAI())
""" """
def __init__(self, llm: BaseLLM, docstore: Docstore, **kwargs: Any): def __init__(self, llm: BaseLanguageModel, docstore: Docstore, **kwargs: Any):
"""Initialize with the LLM and a docstore.""" """Initialize with the LLM and a docstore."""
docstore_explorer = DocstoreExplorer(docstore) docstore_explorer = DocstoreExplorer(docstore)
tools = [ tools = [

View File

@ -9,7 +9,7 @@ from langchain.agents.self_ask_with_search.output_parser import SelfAskOutputPar
from langchain.agents.self_ask_with_search.prompt import PROMPT from langchain.agents.self_ask_with_search.prompt import PROMPT
from langchain.agents.tools import Tool from langchain.agents.tools import Tool
from langchain.agents.utils import validate_tools_single_input from langchain.agents.utils import validate_tools_single_input
from langchain.llms.base import BaseLLM from langchain.base_language import BaseLanguageModel
from langchain.prompts.base import BasePromptTemplate from langchain.prompts.base import BasePromptTemplate
from langchain.tools.base import BaseTool from langchain.tools.base import BaseTool
from langchain.utilities.google_serper import GoogleSerperAPIWrapper from langchain.utilities.google_serper import GoogleSerperAPIWrapper
@ -71,7 +71,7 @@ class SelfAskWithSearchChain(AgentExecutor):
def __init__( def __init__(
self, self,
llm: BaseLLM, llm: BaseLanguageModel,
search_chain: Union[GoogleSerperAPIWrapper, SerpAPIWrapper], search_chain: Union[GoogleSerperAPIWrapper, SerpAPIWrapper],
**kwargs: Any, **kwargs: Any,
): ):

View File

@ -7,12 +7,12 @@ from typing import Any, Dict, List, NamedTuple, Optional, cast
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from requests import Response from requests import Response
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import CallbackManagerForChainRun, Callbacks from langchain.callbacks.manager import CallbackManagerForChainRun, Callbacks
from langchain.chains.api.openapi.requests_chain import APIRequesterChain from langchain.chains.api.openapi.requests_chain import APIRequesterChain
from langchain.chains.api.openapi.response_chain import APIResponderChain from langchain.chains.api.openapi.response_chain import APIResponderChain
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.llms.base import BaseLLM
from langchain.requests import Requests from langchain.requests import Requests
from langchain.tools.openapi.utils.api_models import APIOperation from langchain.tools.openapi.utils.api_models import APIOperation
@ -168,7 +168,7 @@ class OpenAPIEndpointChain(Chain, BaseModel):
spec_url: str, spec_url: str,
path: str, path: str,
method: str, method: str,
llm: BaseLLM, llm: BaseLanguageModel,
requests: Optional[Requests] = None, requests: Optional[Requests] = None,
return_intermediate_steps: bool = False, return_intermediate_steps: bool = False,
**kwargs: Any **kwargs: Any
@ -188,7 +188,7 @@ class OpenAPIEndpointChain(Chain, BaseModel):
def from_api_operation( def from_api_operation(
cls, cls,
operation: APIOperation, operation: APIOperation,
llm: BaseLLM, llm: BaseLanguageModel,
requests: Optional[Requests] = None, requests: Optional[Requests] = None,
verbose: bool = False, verbose: bool = False,
return_intermediate_steps: bool = False, return_intermediate_steps: bool = False,

View File

@ -4,9 +4,9 @@ import json
import re import re
from typing import Any from typing import Any
from langchain.base_language import BaseLanguageModel
from langchain.chains.api.openapi.prompts import REQUEST_TEMPLATE from langchain.chains.api.openapi.prompts import REQUEST_TEMPLATE
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.llms.base import BaseLLM
from langchain.prompts.prompt import PromptTemplate from langchain.prompts.prompt import PromptTemplate
from langchain.schema import BaseOutputParser from langchain.schema import BaseOutputParser
@ -38,7 +38,7 @@ class APIRequesterChain(LLMChain):
@classmethod @classmethod
def from_llm_and_typescript( def from_llm_and_typescript(
cls, cls,
llm: BaseLLM, llm: BaseLanguageModel,
typescript_definition: str, typescript_definition: str,
verbose: bool = True, verbose: bool = True,
**kwargs: Any, **kwargs: Any,

View File

@ -4,9 +4,9 @@ import json
import re import re
from typing import Any from typing import Any
from langchain.base_language import BaseLanguageModel
from langchain.chains.api.openapi.prompts import RESPONSE_TEMPLATE from langchain.chains.api.openapi.prompts import RESPONSE_TEMPLATE
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.llms.base import BaseLLM
from langchain.prompts.prompt import PromptTemplate from langchain.prompts.prompt import PromptTemplate
from langchain.schema import BaseOutputParser from langchain.schema import BaseOutputParser
@ -36,7 +36,9 @@ class APIResponderChain(LLMChain):
"""Get the response parser.""" """Get the response parser."""
@classmethod @classmethod
def from_llm(cls, llm: BaseLLM, verbose: bool = True, **kwargs: Any) -> LLMChain: def from_llm(
cls, llm: BaseLanguageModel, verbose: bool = True, **kwargs: Any
) -> LLMChain:
"""Get the response parser.""" """Get the response parser."""
output_parser = APIResponderOutputParser() output_parser = APIResponderOutputParser()
prompt = PromptTemplate( prompt = PromptTemplate(

View File

@ -5,12 +5,12 @@ from typing import Any, Dict, List, Optional
from pydantic import Field from pydantic import Field
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.chains.graph_qa.prompts import ENTITY_EXTRACTION_PROMPT, PROMPT from langchain.chains.graph_qa.prompts import ENTITY_EXTRACTION_PROMPT, PROMPT
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.graphs.networkx_graph import NetworkxEntityGraph, get_entities from langchain.graphs.networkx_graph import NetworkxEntityGraph, get_entities
from langchain.llms.base import BaseLLM
from langchain.prompts.base import BasePromptTemplate from langchain.prompts.base import BasePromptTemplate
@ -43,7 +43,7 @@ class GraphQAChain(Chain):
@classmethod @classmethod
def from_llm( def from_llm(
cls, cls,
llm: BaseLLM, llm: BaseLanguageModel,
qa_prompt: BasePromptTemplate = PROMPT, qa_prompt: BasePromptTemplate = PROMPT,
entity_prompt: BasePromptTemplate = ENTITY_EXTRACTION_PROMPT, entity_prompt: BasePromptTemplate = ENTITY_EXTRACTION_PROMPT,
**kwargs: Any, **kwargs: Any,

View File

@ -9,12 +9,12 @@ from typing import Any, Dict, List, Optional
import numpy as np import numpy as np
from pydantic import Extra from pydantic import Extra
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.chains.hyde.prompts import PROMPT_MAP from langchain.chains.hyde.prompts import PROMPT_MAP
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.embeddings.base import Embeddings from langchain.embeddings.base import Embeddings
from langchain.llms.base import BaseLLM
class HypotheticalDocumentEmbedder(Chain, Embeddings): class HypotheticalDocumentEmbedder(Chain, Embeddings):
@ -70,7 +70,7 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
@classmethod @classmethod
def from_llm( def from_llm(
cls, cls,
llm: BaseLLM, llm: BaseLanguageModel,
base_embeddings: Embeddings, base_embeddings: Embeddings,
prompt_key: str, prompt_key: str,
**kwargs: Any, **kwargs: Any,

View File

@ -6,6 +6,7 @@ from typing import Any, Dict, List, Optional
from pydantic import Extra, root_validator from pydantic import Extra, root_validator
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
@ -16,12 +17,11 @@ from langchain.chains.llm_checker.prompt import (
REVISED_ANSWER_PROMPT, REVISED_ANSWER_PROMPT,
) )
from langchain.chains.sequential import SequentialChain from langchain.chains.sequential import SequentialChain
from langchain.llms.base import BaseLLM
from langchain.prompts import PromptTemplate from langchain.prompts import PromptTemplate
def _load_question_to_checked_assertions_chain( def _load_question_to_checked_assertions_chain(
llm: BaseLLM, llm: BaseLanguageModel,
create_draft_answer_prompt: PromptTemplate, create_draft_answer_prompt: PromptTemplate,
list_assertions_prompt: PromptTemplate, list_assertions_prompt: PromptTemplate,
check_assertions_prompt: PromptTemplate, check_assertions_prompt: PromptTemplate,
@ -75,7 +75,7 @@ class LLMCheckerChain(Chain):
question_to_checked_assertions_chain: SequentialChain question_to_checked_assertions_chain: SequentialChain
llm: Optional[BaseLLM] = None llm: Optional[BaseLanguageModel] = None
"""[Deprecated] LLM wrapper to use.""" """[Deprecated] LLM wrapper to use."""
create_draft_answer_prompt: PromptTemplate = CREATE_DRAFT_ANSWER_PROMPT create_draft_answer_prompt: PromptTemplate = CREATE_DRAFT_ANSWER_PROMPT
"""[Deprecated]""" """[Deprecated]"""
@ -158,7 +158,7 @@ class LLMCheckerChain(Chain):
@classmethod @classmethod
def from_llm( def from_llm(
cls, cls,
llm: BaseLLM, llm: BaseLanguageModel,
create_draft_answer_prompt: PromptTemplate = CREATE_DRAFT_ANSWER_PROMPT, create_draft_answer_prompt: PromptTemplate = CREATE_DRAFT_ANSWER_PROMPT,
list_assertions_prompt: PromptTemplate = LIST_ASSERTIONS_PROMPT, list_assertions_prompt: PromptTemplate = LIST_ASSERTIONS_PROMPT,
check_assertions_prompt: PromptTemplate = CHECK_ASSERTIONS_PROMPT, check_assertions_prompt: PromptTemplate = CHECK_ASSERTIONS_PROMPT,

View File

@ -8,11 +8,11 @@ from typing import Any, Dict, List, Optional
from pydantic import Extra, root_validator from pydantic import Extra, root_validator
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.chains.sequential import SequentialChain from langchain.chains.sequential import SequentialChain
from langchain.llms.base import BaseLLM
from langchain.prompts.prompt import PromptTemplate from langchain.prompts.prompt import PromptTemplate
PROMPTS_DIR = Path(__file__).parent / "prompts" PROMPTS_DIR = Path(__file__).parent / "prompts"
@ -32,7 +32,7 @@ ARE_ALL_TRUE_PROMPT = PromptTemplate.from_file(
def _load_sequential_chain( def _load_sequential_chain(
llm: BaseLLM, llm: BaseLanguageModel,
create_assertions_prompt: PromptTemplate, create_assertions_prompt: PromptTemplate,
check_assertions_prompt: PromptTemplate, check_assertions_prompt: PromptTemplate,
revised_summary_prompt: PromptTemplate, revised_summary_prompt: PromptTemplate,
@ -85,7 +85,7 @@ class LLMSummarizationCheckerChain(Chain):
""" """
sequential_chain: SequentialChain sequential_chain: SequentialChain
llm: Optional[BaseLLM] = None llm: Optional[BaseLanguageModel] = None
"""[Deprecated] LLM wrapper to use.""" """[Deprecated] LLM wrapper to use."""
create_assertions_prompt: PromptTemplate = CREATE_ASSERTIONS_PROMPT create_assertions_prompt: PromptTemplate = CREATE_ASSERTIONS_PROMPT
@ -180,7 +180,7 @@ class LLMSummarizationCheckerChain(Chain):
@classmethod @classmethod
def from_llm( def from_llm(
cls, cls,
llm: BaseLLM, llm: BaseLanguageModel,
create_assertions_prompt: PromptTemplate = CREATE_ASSERTIONS_PROMPT, create_assertions_prompt: PromptTemplate = CREATE_ASSERTIONS_PROMPT,
check_assertions_prompt: PromptTemplate = CHECK_ASSERTIONS_PROMPT, check_assertions_prompt: PromptTemplate = CHECK_ASSERTIONS_PROMPT,
revised_summary_prompt: PromptTemplate = REVISED_SUMMARY_PROMPT, revised_summary_prompt: PromptTemplate = REVISED_SUMMARY_PROMPT,

View File

@ -9,6 +9,7 @@ from typing import Any, Dict, List, Optional
from pydantic import Extra from pydantic import Extra
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import CallbackManagerForChainRun, Callbacks from langchain.callbacks.manager import CallbackManagerForChainRun, Callbacks
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
@ -16,7 +17,6 @@ from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChai
from langchain.chains.combine_documents.stuff import StuffDocumentsChain from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.docstore.document import Document from langchain.docstore.document import Document
from langchain.llms.base import BaseLLM
from langchain.prompts.base import BasePromptTemplate from langchain.prompts.base import BasePromptTemplate
from langchain.text_splitter import TextSplitter from langchain.text_splitter import TextSplitter
@ -34,7 +34,7 @@ class MapReduceChain(Chain):
@classmethod @classmethod
def from_params( def from_params(
cls, cls,
llm: BaseLLM, llm: BaseLanguageModel,
prompt: BasePromptTemplate, prompt: BasePromptTemplate,
text_splitter: TextSplitter, text_splitter: TextSplitter,
callbacks: Callbacks = None, callbacks: Callbacks = None,

View File

@ -6,11 +6,11 @@ from typing import Any, Dict, List, Optional
from pydantic import Extra, root_validator from pydantic import Extra, root_validator
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.chains.natbot.prompt import PROMPT from langchain.chains.natbot.prompt import PROMPT
from langchain.llms.base import BaseLLM
from langchain.llms.openai import OpenAI from langchain.llms.openai import OpenAI
@ -27,7 +27,7 @@ class NatBotChain(Chain):
llm_chain: LLMChain llm_chain: LLMChain
objective: str objective: str
"""Objective that NatBot is tasked with completing.""" """Objective that NatBot is tasked with completing."""
llm: Optional[BaseLLM] = None llm: Optional[BaseLanguageModel] = None
"""[Deprecated] LLM wrapper to use.""" """[Deprecated] LLM wrapper to use."""
input_url_key: str = "url" #: :meta private: input_url_key: str = "url" #: :meta private:
input_browser_content_key: str = "browser_content" #: :meta private: input_browser_content_key: str = "browser_content" #: :meta private:
@ -59,7 +59,9 @@ class NatBotChain(Chain):
return cls.from_llm(llm, objective, **kwargs) return cls.from_llm(llm, objective, **kwargs)
@classmethod @classmethod
def from_llm(cls, llm: BaseLLM, objective: str, **kwargs: Any) -> NatBotChain: def from_llm(
cls, llm: BaseLanguageModel, objective: str, **kwargs: Any
) -> NatBotChain:
"""Load from LLM.""" """Load from LLM."""
llm_chain = LLMChain(llm=llm, prompt=PROMPT) llm_chain = LLMChain(llm=llm, prompt=PROMPT)
return cls(llm_chain=llm_chain, objective=objective, **kwargs) return cls(llm_chain=llm_chain, objective=objective, **kwargs)

View File

@ -4,9 +4,9 @@ from __future__ import annotations
from typing import Any, List from typing import Any, List
from langchain import PromptTemplate from langchain import PromptTemplate
from langchain.base_language import BaseLanguageModel
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.evaluation.qa.eval_prompt import CONTEXT_PROMPT, COT_PROMPT, PROMPT from langchain.evaluation.qa.eval_prompt import CONTEXT_PROMPT, COT_PROMPT, PROMPT
from langchain.llms.base import BaseLLM
class QAEvalChain(LLMChain): class QAEvalChain(LLMChain):
@ -14,12 +14,12 @@ class QAEvalChain(LLMChain):
@classmethod @classmethod
def from_llm( def from_llm(
cls, llm: BaseLLM, prompt: PromptTemplate = PROMPT, **kwargs: Any cls, llm: BaseLanguageModel, prompt: PromptTemplate = PROMPT, **kwargs: Any
) -> QAEvalChain: ) -> QAEvalChain:
"""Load QA Eval Chain from LLM. """Load QA Eval Chain from LLM.
Args: Args:
llm (BaseLLM): the base language model to use. llm (BaseLanguageModel): the base language model to use.
prompt (PromptTemplate): A prompt template containing the input_variables: prompt (PromptTemplate): A prompt template containing the input_variables:
'input', 'answer' and 'result' that will be used as the prompt 'input', 'answer' and 'result' that will be used as the prompt
@ -74,12 +74,15 @@ class ContextQAEvalChain(LLMChain):
@classmethod @classmethod
def from_llm( def from_llm(
cls, llm: BaseLLM, prompt: PromptTemplate = CONTEXT_PROMPT, **kwargs: Any cls,
llm: BaseLanguageModel,
prompt: PromptTemplate = CONTEXT_PROMPT,
**kwargs: Any,
) -> ContextQAEvalChain: ) -> ContextQAEvalChain:
"""Load QA Eval Chain from LLM. """Load QA Eval Chain from LLM.
Args: Args:
llm (BaseLLM): the base language model to use. llm (BaseLanguageModel): the base language model to use.
prompt (PromptTemplate): A prompt template containing the input_variables: prompt (PromptTemplate): A prompt template containing the input_variables:
'query', 'context' and 'result' that will be used as the prompt 'query', 'context' and 'result' that will be used as the prompt
@ -120,7 +123,7 @@ class CotQAEvalChain(ContextQAEvalChain):
@classmethod @classmethod
def from_llm( def from_llm(
cls, llm: BaseLLM, prompt: PromptTemplate = COT_PROMPT, **kwargs: Any cls, llm: BaseLanguageModel, prompt: PromptTemplate = COT_PROMPT, **kwargs: Any
) -> CotQAEvalChain: ) -> CotQAEvalChain:
cls._validate_input_vars(prompt) cls._validate_input_vars(prompt)
return cls(llm=llm, prompt=prompt, **kwargs) return cls(llm=llm, prompt=prompt, **kwargs)

View File

@ -3,15 +3,15 @@ from __future__ import annotations
from typing import Any from typing import Any
from langchain.base_language import BaseLanguageModel
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.evaluation.qa.generate_prompt import PROMPT from langchain.evaluation.qa.generate_prompt import PROMPT
from langchain.llms.base import BaseLLM
class QAGenerateChain(LLMChain): class QAGenerateChain(LLMChain):
"""LLM Chain specifically for generating examples for question answering.""" """LLM Chain specifically for generating examples for question answering."""
@classmethod @classmethod
def from_llm(cls, llm: BaseLLM, **kwargs: Any) -> QAGenerateChain: def from_llm(cls, llm: BaseLanguageModel, **kwargs: Any) -> QAGenerateChain:
"""Load QA Generate Chain from LLM.""" """Load QA Generate Chain from LLM."""
return cls(llm=llm, prompt=PROMPT, **kwargs) return cls(llm=llm, prompt=PROMPT, **kwargs)

View File

@ -1,8 +1,8 @@
"""Utility functions for working with prompts.""" """Utility functions for working with prompts."""
from typing import List from typing import List
from langchain.base_language import BaseLanguageModel
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.llms.base import BaseLLM
from langchain.prompts.few_shot import FewShotPromptTemplate from langchain.prompts.few_shot import FewShotPromptTemplate
from langchain.prompts.prompt import PromptTemplate from langchain.prompts.prompt import PromptTemplate
@ -10,7 +10,7 @@ TEST_GEN_TEMPLATE_SUFFIX = "Add another example."
def generate_example( def generate_example(
examples: List[dict], llm: BaseLLM, prompt_template: PromptTemplate examples: List[dict], llm: BaseLanguageModel, prompt_template: PromptTemplate
) -> str: ) -> str:
"""Return another example given a list of examples for a prompt.""" """Return another example given a list of examples for a prompt."""
prompt = FewShotPromptTemplate( prompt = FewShotPromptTemplate(

View File

@ -3,18 +3,18 @@ from typing import Optional, Type
from pydantic import BaseModel from pydantic import BaseModel
from langchain.base_language import BaseLanguageModel
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.graphs.networkx_graph import NetworkxEntityGraph, parse_triples from langchain.graphs.networkx_graph import NetworkxEntityGraph, parse_triples
from langchain.indexes.prompts.knowledge_triplet_extraction import ( from langchain.indexes.prompts.knowledge_triplet_extraction import (
KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT, KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT,
) )
from langchain.llms.base import BaseLLM
class GraphIndexCreator(BaseModel): class GraphIndexCreator(BaseModel):
"""Functionality to create graph index.""" """Functionality to create graph index."""
llm: Optional[BaseLLM] = None llm: Optional[BaseLanguageModel] = None
graph_type: Type[NetworkxEntityGraph] = NetworkxEntityGraph graph_type: Type[NetworkxEntityGraph] = NetworkxEntityGraph
def from_text(self, text: str) -> NetworkxEntityGraph: def from_text(self, text: str) -> NetworkxEntityGraph:

View File

@ -2,12 +2,12 @@ from typing import Any, List, Optional, Type
from pydantic import BaseModel, Extra, Field from pydantic import BaseModel, Extra, Field
from langchain.base_language import BaseLanguageModel
from langchain.chains.qa_with_sources.retrieval import RetrievalQAWithSourcesChain from langchain.chains.qa_with_sources.retrieval import RetrievalQAWithSourcesChain
from langchain.chains.retrieval_qa.base import RetrievalQA from langchain.chains.retrieval_qa.base import RetrievalQA
from langchain.document_loaders.base import BaseLoader from langchain.document_loaders.base import BaseLoader
from langchain.embeddings.base import Embeddings from langchain.embeddings.base import Embeddings
from langchain.embeddings.openai import OpenAIEmbeddings from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.llms.base import BaseLLM
from langchain.llms.openai import OpenAI from langchain.llms.openai import OpenAI
from langchain.schema import Document from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
@ -30,7 +30,9 @@ class VectorStoreIndexWrapper(BaseModel):
extra = Extra.forbid extra = Extra.forbid
arbitrary_types_allowed = True arbitrary_types_allowed = True
def query(self, question: str, llm: Optional[BaseLLM] = None, **kwargs: Any) -> str: def query(
self, question: str, llm: Optional[BaseLanguageModel] = None, **kwargs: Any
) -> str:
"""Query the vectorstore.""" """Query the vectorstore."""
llm = llm or OpenAI(temperature=0) llm = llm or OpenAI(temperature=0)
chain = RetrievalQA.from_chain_type( chain = RetrievalQA.from_chain_type(
@ -39,7 +41,7 @@ class VectorStoreIndexWrapper(BaseModel):
return chain.run(question) return chain.run(question)
def query_with_sources( def query_with_sources(
self, question: str, llm: Optional[BaseLLM] = None, **kwargs: Any self, question: str, llm: Optional[BaseLanguageModel] = None, **kwargs: Any
) -> dict: ) -> dict:
"""Query the vectorstore and get back sources.""" """Query the vectorstore and get back sources."""
llm = llm or OpenAI(temperature=0) llm = llm or OpenAI(temperature=0)