Replace remaining usage of basellm with baselangmodel (#3981)

fix_agent_callbacks
Nuno Campos 1 year ago committed by GitHub
parent f291fd7eed
commit f3ec6d2449
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -3,11 +3,14 @@ from typing import Any, Optional
from langchain.agents.agent import AgentExecutor
from langchain.agents.agent_toolkits.pandas.base import create_pandas_dataframe_agent
from langchain.llms.base import BaseLLM
from langchain.base_language import BaseLanguageModel
def create_csv_agent(
llm: BaseLLM, path: str, pandas_kwargs: Optional[dict] = None, **kwargs: Any
llm: BaseLanguageModel,
path: str,
pandas_kwargs: Optional[dict] = None,
**kwargs: Any
) -> AgentExecutor:
"""Create csv agent by loading to a dataframe and using pandas agent."""
import pandas as pd

@ -6,13 +6,13 @@ from langchain.agents.agent_toolkits.json.prompt import JSON_PREFIX, JSON_SUFFIX
from langchain.agents.agent_toolkits.json.toolkit import JsonToolkit
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.llm import LLMChain
from langchain.llms.base import BaseLLM
def create_json_agent(
llm: BaseLLM,
llm: BaseLanguageModel,
toolkit: JsonToolkit,
callback_manager: Optional[BaseCallbackManager] = None,
prefix: str = JSON_PREFIX,

@ -4,8 +4,8 @@
from typing import Any, Optional
from langchain.agents.tools import Tool
from langchain.base_language import BaseLanguageModel
from langchain.chains.api.openapi.chain import OpenAPIEndpointChain
from langchain.llms.base import BaseLLM
from langchain.requests import Requests
from langchain.tools.openapi.utils.api_models import APIOperation
from langchain.tools.openapi.utils.openapi_utils import OpenAPISpec
@ -32,7 +32,7 @@ class NLATool(Tool):
@classmethod
def from_llm_and_method(
cls,
llm: BaseLLM,
llm: BaseLanguageModel,
path: str,
method: str,
spec: OpenAPISpec,

@ -7,7 +7,7 @@ from pydantic import Field
from langchain.agents.agent_toolkits.base import BaseToolkit
from langchain.agents.agent_toolkits.nla.tool import NLATool
from langchain.llms.base import BaseLLM
from langchain.base_language import BaseLanguageModel
from langchain.requests import Requests
from langchain.tools.base import BaseTool
from langchain.tools.openapi.utils.openapi_utils import OpenAPISpec
@ -26,7 +26,7 @@ class NLAToolkit(BaseToolkit):
@staticmethod
def _get_http_operation_tools(
llm: BaseLLM,
llm: BaseLanguageModel,
spec: OpenAPISpec,
requests: Optional[Requests] = None,
verbose: bool = False,
@ -53,7 +53,7 @@ class NLAToolkit(BaseToolkit):
@classmethod
def from_llm_and_spec(
cls,
llm: BaseLLM,
llm: BaseLanguageModel,
spec: OpenAPISpec,
requests: Optional[Requests] = None,
verbose: bool = False,
@ -68,7 +68,7 @@ class NLAToolkit(BaseToolkit):
@classmethod
def from_llm_and_url(
cls,
llm: BaseLLM,
llm: BaseLanguageModel,
open_api_url: str,
requests: Optional[Requests] = None,
verbose: bool = False,
@ -83,7 +83,7 @@ class NLAToolkit(BaseToolkit):
@classmethod
def from_llm_and_ai_plugin(
cls,
llm: BaseLLM,
llm: BaseLanguageModel,
ai_plugin: AIPlugin,
requests: Optional[Requests] = None,
verbose: bool = False,
@ -103,7 +103,7 @@ class NLAToolkit(BaseToolkit):
@classmethod
def from_llm_and_ai_plugin_url(
cls,
llm: BaseLLM,
llm: BaseLanguageModel,
ai_plugin_url: str,
requests: Optional[Requests] = None,
verbose: bool = False,

@ -9,13 +9,13 @@ from langchain.agents.agent_toolkits.openapi.prompt import (
from langchain.agents.agent_toolkits.openapi.toolkit import OpenAPIToolkit
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.llm import LLMChain
from langchain.llms.base import BaseLLM
def create_openapi_agent(
llm: BaseLLM,
llm: BaseLanguageModel,
toolkit: OpenAPIToolkit,
callback_manager: Optional[BaseCallbackManager] = None,
prefix: str = OPENAPI_PREFIX,

@ -9,7 +9,7 @@ from langchain.agents.agent_toolkits.json.base import create_json_agent
from langchain.agents.agent_toolkits.json.toolkit import JsonToolkit
from langchain.agents.agent_toolkits.openapi.prompt import DESCRIPTION
from langchain.agents.tools import Tool
from langchain.llms.base import BaseLLM
from langchain.base_language import BaseLanguageModel
from langchain.requests import TextRequestsWrapper
from langchain.tools import BaseTool
from langchain.tools.json.tool import JsonSpec
@ -57,7 +57,7 @@ class OpenAPIToolkit(BaseToolkit):
@classmethod
def from_llm(
cls,
llm: BaseLLM,
llm: BaseLanguageModel,
json_spec: JsonSpec,
requests_wrapper: TextRequestsWrapper,
**kwargs: Any,

@ -4,14 +4,14 @@ from typing import Any, Dict, List, Optional
from langchain.agents.agent import AgentExecutor
from langchain.agents.agent_toolkits.pandas.prompt import PREFIX, SUFFIX
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.llm import LLMChain
from langchain.llms.base import BaseLLM
from langchain.tools.python.tool import PythonAstREPLTool
def create_pandas_dataframe_agent(
llm: BaseLLM,
llm: BaseLanguageModel,
df: Any,
callback_manager: Optional[BaseCallbackManager] = None,
prefix: str = PREFIX,

@ -9,14 +9,14 @@ from langchain.agents.agent_toolkits.powerbi.prompt import (
from langchain.agents.agent_toolkits.powerbi.toolkit import PowerBIToolkit
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.llm import LLMChain
from langchain.llms.base import BaseLLM
from langchain.utilities.powerbi import PowerBIDataset
def create_pbi_agent(
llm: BaseLLM,
llm: BaseLanguageModel,
toolkit: Optional[PowerBIToolkit],
powerbi: Optional[PowerBIDataset] = None,
callback_manager: Optional[BaseCallbackManager] = None,

@ -5,14 +5,14 @@ from typing import Any, Dict, Optional
from langchain.agents.agent import AgentExecutor
from langchain.agents.agent_toolkits.python.prompt import PREFIX
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.llm import LLMChain
from langchain.llms.base import BaseLLM
from langchain.tools.python.tool import PythonREPLTool
def create_python_agent(
llm: BaseLLM,
llm: BaseLanguageModel,
tool: PythonREPLTool,
callback_manager: Optional[BaseCallbackManager] = None,
verbose: bool = False,

@ -6,13 +6,13 @@ from langchain.agents.agent_toolkits.sql.prompt import SQL_PREFIX, SQL_SUFFIX
from langchain.agents.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.llm import LLMChain
from langchain.llms.base import BaseLLM
def create_sql_agent(
llm: BaseLLM,
llm: BaseLanguageModel,
toolkit: SQLDatabaseToolkit,
callback_manager: Optional[BaseCallbackManager] = None,
prefix: str = SQL_PREFIX,

@ -8,13 +8,13 @@ from langchain.agents.agent_toolkits.vectorstore.toolkit import (
VectorStoreToolkit,
)
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.llm import LLMChain
from langchain.llms.base import BaseLLM
def create_vectorstore_agent(
llm: BaseLLM,
llm: BaseLanguageModel,
toolkit: VectorStoreToolkit,
callback_manager: Optional[BaseCallbackManager] = None,
prefix: str = PREFIX,
@ -42,7 +42,7 @@ def create_vectorstore_agent(
def create_vectorstore_router_agent(
llm: BaseLLM,
llm: BaseLanguageModel,
toolkit: VectorStoreRouterToolkit,
callback_manager: Optional[BaseCallbackManager] = None,
prefix: str = ROUTER_PREFIX,

@ -5,12 +5,12 @@ from typing import Any, Dict, List, Optional, Callable, Tuple
from mypy_extensions import Arg, KwArg
from langchain.agents.tools import Tool
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.api import news_docs, open_meteo_docs, podcast_docs, tmdb_docs
from langchain.chains.api.base import APIChain
from langchain.chains.llm_math.base import LLMMathChain
from langchain.chains.pal.base import PALChain
from langchain.llms.base import BaseLLM
from langchain.requests import TextRequestsWrapper
from langchain.tools.arxiv.tool import ArxivQueryRun
from langchain.tools.base import BaseTool
@ -32,8 +32,6 @@ from langchain.tools.shell.tool import ShellTool
from langchain.tools.wikipedia.tool import WikipediaQueryRun
from langchain.tools.wolfram_alpha.tool import WolframAlphaQueryRun
from langchain.utilities import ArxivAPIWrapper
from langchain.utilities.apify import ApifyWrapper
from langchain.utilities.bash import BashProcess
from langchain.utilities.bing_search import BingSearchAPIWrapper
from langchain.utilities.duckduckgo_search import DuckDuckGoSearchAPIWrapper
from langchain.utilities.google_search import GoogleSearchAPIWrapper
@ -85,7 +83,7 @@ _BASE_TOOLS: Dict[str, Callable[[], BaseTool]] = {
}
def _get_pal_math(llm: BaseLLM) -> BaseTool:
def _get_pal_math(llm: BaseLanguageModel) -> BaseTool:
return Tool(
name="PAL-MATH",
description="A language model that is really good at solving complex word math problems. Input should be a fully worded hard word math problem.",
@ -93,7 +91,7 @@ def _get_pal_math(llm: BaseLLM) -> BaseTool:
)
def _get_pal_colored_objects(llm: BaseLLM) -> BaseTool:
def _get_pal_colored_objects(llm: BaseLanguageModel) -> BaseTool:
return Tool(
name="PAL-COLOR-OBJ",
description="A language model that is really good at reasoning about position and the color attributes of objects. Input should be a fully worded hard reasoning problem. Make sure to include all information about the objects AND the final question you want to answer.",
@ -101,7 +99,7 @@ def _get_pal_colored_objects(llm: BaseLLM) -> BaseTool:
)
def _get_llm_math(llm: BaseLLM) -> BaseTool:
def _get_llm_math(llm: BaseLanguageModel) -> BaseTool:
return Tool(
name="Calculator",
description="Useful for when you need to answer questions about math.",
@ -110,7 +108,7 @@ def _get_llm_math(llm: BaseLLM) -> BaseTool:
)
def _get_open_meteo_api(llm: BaseLLM) -> BaseTool:
def _get_open_meteo_api(llm: BaseLanguageModel) -> BaseTool:
chain = APIChain.from_llm_and_api_docs(llm, open_meteo_docs.OPEN_METEO_DOCS)
return Tool(
name="Open Meteo API",
@ -119,7 +117,7 @@ def _get_open_meteo_api(llm: BaseLLM) -> BaseTool:
)
_LLM_TOOLS: Dict[str, Callable[[BaseLLM], BaseTool]] = {
_LLM_TOOLS: Dict[str, Callable[[BaseLanguageModel], BaseTool]] = {
"pal-math": _get_pal_math,
"pal-colored-objects": _get_pal_colored_objects,
"llm-math": _get_llm_math,
@ -127,7 +125,7 @@ _LLM_TOOLS: Dict[str, Callable[[BaseLLM], BaseTool]] = {
}
def _get_news_api(llm: BaseLLM, **kwargs: Any) -> BaseTool:
def _get_news_api(llm: BaseLanguageModel, **kwargs: Any) -> BaseTool:
news_api_key = kwargs["news_api_key"]
chain = APIChain.from_llm_and_api_docs(
llm, news_docs.NEWS_DOCS, headers={"X-Api-Key": news_api_key}
@ -139,7 +137,7 @@ def _get_news_api(llm: BaseLLM, **kwargs: Any) -> BaseTool:
)
def _get_tmdb_api(llm: BaseLLM, **kwargs: Any) -> BaseTool:
def _get_tmdb_api(llm: BaseLanguageModel, **kwargs: Any) -> BaseTool:
tmdb_bearer_token = kwargs["tmdb_bearer_token"]
chain = APIChain.from_llm_and_api_docs(
llm,
@ -153,7 +151,7 @@ def _get_tmdb_api(llm: BaseLLM, **kwargs: Any) -> BaseTool:
)
def _get_podcast_api(llm: BaseLLM, **kwargs: Any) -> BaseTool:
def _get_podcast_api(llm: BaseLanguageModel, **kwargs: Any) -> BaseTool:
listen_api_key = kwargs["listen_api_key"]
chain = APIChain.from_llm_and_api_docs(
llm,
@ -238,7 +236,8 @@ def _get_scenexplain(**kwargs: Any) -> BaseTool:
_EXTRA_LLM_TOOLS: Dict[
str, Tuple[Callable[[Arg(BaseLLM, "llm"), KwArg(Any)], BaseTool], List[str]]
str,
Tuple[Callable[[Arg(BaseLanguageModel, "llm"), KwArg(Any)], BaseTool], List[str]],
] = {
"news-api": (_get_news_api, ["news_api_key"]),
"tmdb-api": (_get_tmdb_api, ["tmdb_bearer_token"]),
@ -277,7 +276,7 @@ _EXTRA_OPTIONAL_TOOLS: Dict[str, Tuple[Callable[[KwArg(Any)], BaseTool], List[st
def load_tools(
tool_names: List[str],
llm: Optional[BaseLLM] = None,
llm: Optional[BaseLanguageModel] = None,
callback_manager: Optional[BaseCallbackManager] = None,
**kwargs: Any,
) -> List[BaseTool]:

@ -8,15 +8,15 @@ import yaml
from langchain.agents.agent import BaseSingleActionAgent
from langchain.agents.tools import Tool
from langchain.agents.types import AGENT_TO_CLASS
from langchain.base_language import BaseLanguageModel
from langchain.chains.loading import load_chain, load_chain_from_config
from langchain.llms.base import BaseLLM
from langchain.utilities.loading import try_load_from_hub
URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/agents/"
def _load_agent_from_tools(
config: dict, llm: BaseLLM, tools: List[Tool], **kwargs: Any
config: dict, llm: BaseLanguageModel, tools: List[Tool], **kwargs: Any
) -> BaseSingleActionAgent:
config_type = config.pop("_type")
if config_type not in AGENT_TO_CLASS:
@ -29,7 +29,7 @@ def _load_agent_from_tools(
def load_agent_from_config(
config: dict,
llm: Optional[BaseLLM] = None,
llm: Optional[BaseLanguageModel] = None,
tools: Optional[List[Tool]] = None,
**kwargs: Any,
) -> BaseSingleActionAgent:

@ -10,9 +10,9 @@ from langchain.agents.react.textworld_prompt import TEXTWORLD_PROMPT
from langchain.agents.react.wiki_prompt import WIKI_PROMPT
from langchain.agents.tools import Tool
from langchain.agents.utils import validate_tools_single_input
from langchain.base_language import BaseLanguageModel
from langchain.docstore.base import Docstore
from langchain.docstore.document import Document
from langchain.llms.base import BaseLLM
from langchain.prompts.base import BasePromptTemplate
from langchain.tools.base import BaseTool
@ -141,7 +141,7 @@ class ReActChain(AgentExecutor):
react = ReAct(llm=OpenAI())
"""
def __init__(self, llm: BaseLLM, docstore: Docstore, **kwargs: Any):
def __init__(self, llm: BaseLanguageModel, docstore: Docstore, **kwargs: Any):
"""Initialize with the LLM and a docstore."""
docstore_explorer = DocstoreExplorer(docstore)
tools = [

@ -9,7 +9,7 @@ from langchain.agents.self_ask_with_search.output_parser import SelfAskOutputPar
from langchain.agents.self_ask_with_search.prompt import PROMPT
from langchain.agents.tools import Tool
from langchain.agents.utils import validate_tools_single_input
from langchain.llms.base import BaseLLM
from langchain.base_language import BaseLanguageModel
from langchain.prompts.base import BasePromptTemplate
from langchain.tools.base import BaseTool
from langchain.utilities.google_serper import GoogleSerperAPIWrapper
@ -71,7 +71,7 @@ class SelfAskWithSearchChain(AgentExecutor):
def __init__(
self,
llm: BaseLLM,
llm: BaseLanguageModel,
search_chain: Union[GoogleSerperAPIWrapper, SerpAPIWrapper],
**kwargs: Any,
):

@ -7,12 +7,12 @@ from typing import Any, Dict, List, NamedTuple, Optional, cast
from pydantic import BaseModel, Field
from requests import Response
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import CallbackManagerForChainRun, Callbacks
from langchain.chains.api.openapi.requests_chain import APIRequesterChain
from langchain.chains.api.openapi.response_chain import APIResponderChain
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.llms.base import BaseLLM
from langchain.requests import Requests
from langchain.tools.openapi.utils.api_models import APIOperation
@ -168,7 +168,7 @@ class OpenAPIEndpointChain(Chain, BaseModel):
spec_url: str,
path: str,
method: str,
llm: BaseLLM,
llm: BaseLanguageModel,
requests: Optional[Requests] = None,
return_intermediate_steps: bool = False,
**kwargs: Any
@ -188,7 +188,7 @@ class OpenAPIEndpointChain(Chain, BaseModel):
def from_api_operation(
cls,
operation: APIOperation,
llm: BaseLLM,
llm: BaseLanguageModel,
requests: Optional[Requests] = None,
verbose: bool = False,
return_intermediate_steps: bool = False,

@ -4,9 +4,9 @@ import json
import re
from typing import Any
from langchain.base_language import BaseLanguageModel
from langchain.chains.api.openapi.prompts import REQUEST_TEMPLATE
from langchain.chains.llm import LLMChain
from langchain.llms.base import BaseLLM
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import BaseOutputParser
@ -38,7 +38,7 @@ class APIRequesterChain(LLMChain):
@classmethod
def from_llm_and_typescript(
cls,
llm: BaseLLM,
llm: BaseLanguageModel,
typescript_definition: str,
verbose: bool = True,
**kwargs: Any,

@ -4,9 +4,9 @@ import json
import re
from typing import Any
from langchain.base_language import BaseLanguageModel
from langchain.chains.api.openapi.prompts import RESPONSE_TEMPLATE
from langchain.chains.llm import LLMChain
from langchain.llms.base import BaseLLM
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import BaseOutputParser
@ -36,7 +36,9 @@ class APIResponderChain(LLMChain):
"""Get the response parser."""
@classmethod
def from_llm(cls, llm: BaseLLM, verbose: bool = True, **kwargs: Any) -> LLMChain:
def from_llm(
cls, llm: BaseLanguageModel, verbose: bool = True, **kwargs: Any
) -> LLMChain:
"""Get the response parser."""
output_parser = APIResponderOutputParser()
prompt = PromptTemplate(

@ -5,12 +5,12 @@ from typing import Any, Dict, List, Optional
from pydantic import Field
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.chains.graph_qa.prompts import ENTITY_EXTRACTION_PROMPT, PROMPT
from langchain.chains.llm import LLMChain
from langchain.graphs.networkx_graph import NetworkxEntityGraph, get_entities
from langchain.llms.base import BaseLLM
from langchain.prompts.base import BasePromptTemplate
@ -43,7 +43,7 @@ class GraphQAChain(Chain):
@classmethod
def from_llm(
cls,
llm: BaseLLM,
llm: BaseLanguageModel,
qa_prompt: BasePromptTemplate = PROMPT,
entity_prompt: BasePromptTemplate = ENTITY_EXTRACTION_PROMPT,
**kwargs: Any,

@ -9,12 +9,12 @@ from typing import Any, Dict, List, Optional
import numpy as np
from pydantic import Extra
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.chains.hyde.prompts import PROMPT_MAP
from langchain.chains.llm import LLMChain
from langchain.embeddings.base import Embeddings
from langchain.llms.base import BaseLLM
class HypotheticalDocumentEmbedder(Chain, Embeddings):
@ -70,7 +70,7 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
@classmethod
def from_llm(
cls,
llm: BaseLLM,
llm: BaseLanguageModel,
base_embeddings: Embeddings,
prompt_key: str,
**kwargs: Any,

@ -6,6 +6,7 @@ from typing import Any, Dict, List, Optional
from pydantic import Extra, root_validator
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
@ -16,12 +17,11 @@ from langchain.chains.llm_checker.prompt import (
REVISED_ANSWER_PROMPT,
)
from langchain.chains.sequential import SequentialChain
from langchain.llms.base import BaseLLM
from langchain.prompts import PromptTemplate
def _load_question_to_checked_assertions_chain(
llm: BaseLLM,
llm: BaseLanguageModel,
create_draft_answer_prompt: PromptTemplate,
list_assertions_prompt: PromptTemplate,
check_assertions_prompt: PromptTemplate,
@ -75,7 +75,7 @@ class LLMCheckerChain(Chain):
question_to_checked_assertions_chain: SequentialChain
llm: Optional[BaseLLM] = None
llm: Optional[BaseLanguageModel] = None
"""[Deprecated] LLM wrapper to use."""
create_draft_answer_prompt: PromptTemplate = CREATE_DRAFT_ANSWER_PROMPT
"""[Deprecated]"""
@ -158,7 +158,7 @@ class LLMCheckerChain(Chain):
@classmethod
def from_llm(
cls,
llm: BaseLLM,
llm: BaseLanguageModel,
create_draft_answer_prompt: PromptTemplate = CREATE_DRAFT_ANSWER_PROMPT,
list_assertions_prompt: PromptTemplate = LIST_ASSERTIONS_PROMPT,
check_assertions_prompt: PromptTemplate = CHECK_ASSERTIONS_PROMPT,

@ -8,11 +8,11 @@ from typing import Any, Dict, List, Optional
from pydantic import Extra, root_validator
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.sequential import SequentialChain
from langchain.llms.base import BaseLLM
from langchain.prompts.prompt import PromptTemplate
PROMPTS_DIR = Path(__file__).parent / "prompts"
@ -32,7 +32,7 @@ ARE_ALL_TRUE_PROMPT = PromptTemplate.from_file(
def _load_sequential_chain(
llm: BaseLLM,
llm: BaseLanguageModel,
create_assertions_prompt: PromptTemplate,
check_assertions_prompt: PromptTemplate,
revised_summary_prompt: PromptTemplate,
@ -85,7 +85,7 @@ class LLMSummarizationCheckerChain(Chain):
"""
sequential_chain: SequentialChain
llm: Optional[BaseLLM] = None
llm: Optional[BaseLanguageModel] = None
"""[Deprecated] LLM wrapper to use."""
create_assertions_prompt: PromptTemplate = CREATE_ASSERTIONS_PROMPT
@ -180,7 +180,7 @@ class LLMSummarizationCheckerChain(Chain):
@classmethod
def from_llm(
cls,
llm: BaseLLM,
llm: BaseLanguageModel,
create_assertions_prompt: PromptTemplate = CREATE_ASSERTIONS_PROMPT,
check_assertions_prompt: PromptTemplate = CHECK_ASSERTIONS_PROMPT,
revised_summary_prompt: PromptTemplate = REVISED_SUMMARY_PROMPT,

@ -9,6 +9,7 @@ from typing import Any, Dict, List, Optional
from pydantic import Extra
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import CallbackManagerForChainRun, Callbacks
from langchain.chains.base import Chain
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
@ -16,7 +17,6 @@ from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChai
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.docstore.document import Document
from langchain.llms.base import BaseLLM
from langchain.prompts.base import BasePromptTemplate
from langchain.text_splitter import TextSplitter
@ -34,7 +34,7 @@ class MapReduceChain(Chain):
@classmethod
def from_params(
cls,
llm: BaseLLM,
llm: BaseLanguageModel,
prompt: BasePromptTemplate,
text_splitter: TextSplitter,
callbacks: Callbacks = None,

@ -6,11 +6,11 @@ from typing import Any, Dict, List, Optional
from pydantic import Extra, root_validator
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.natbot.prompt import PROMPT
from langchain.llms.base import BaseLLM
from langchain.llms.openai import OpenAI
@ -27,7 +27,7 @@ class NatBotChain(Chain):
llm_chain: LLMChain
objective: str
"""Objective that NatBot is tasked with completing."""
llm: Optional[BaseLLM] = None
llm: Optional[BaseLanguageModel] = None
"""[Deprecated] LLM wrapper to use."""
input_url_key: str = "url" #: :meta private:
input_browser_content_key: str = "browser_content" #: :meta private:
@ -59,7 +59,9 @@ class NatBotChain(Chain):
return cls.from_llm(llm, objective, **kwargs)
@classmethod
def from_llm(cls, llm: BaseLLM, objective: str, **kwargs: Any) -> NatBotChain:
def from_llm(
cls, llm: BaseLanguageModel, objective: str, **kwargs: Any
) -> NatBotChain:
"""Load from LLM."""
llm_chain = LLMChain(llm=llm, prompt=PROMPT)
return cls(llm_chain=llm_chain, objective=objective, **kwargs)

@ -4,9 +4,9 @@ from __future__ import annotations
from typing import Any, List
from langchain import PromptTemplate
from langchain.base_language import BaseLanguageModel
from langchain.chains.llm import LLMChain
from langchain.evaluation.qa.eval_prompt import CONTEXT_PROMPT, COT_PROMPT, PROMPT
from langchain.llms.base import BaseLLM
class QAEvalChain(LLMChain):
@ -14,12 +14,12 @@ class QAEvalChain(LLMChain):
@classmethod
def from_llm(
cls, llm: BaseLLM, prompt: PromptTemplate = PROMPT, **kwargs: Any
cls, llm: BaseLanguageModel, prompt: PromptTemplate = PROMPT, **kwargs: Any
) -> QAEvalChain:
"""Load QA Eval Chain from LLM.
Args:
llm (BaseLLM): the base language model to use.
llm (BaseLanguageModel): the base language model to use.
prompt (PromptTemplate): A prompt template containing the input_variables:
'input', 'answer' and 'result' that will be used as the prompt
@ -74,12 +74,15 @@ class ContextQAEvalChain(LLMChain):
@classmethod
def from_llm(
cls, llm: BaseLLM, prompt: PromptTemplate = CONTEXT_PROMPT, **kwargs: Any
cls,
llm: BaseLanguageModel,
prompt: PromptTemplate = CONTEXT_PROMPT,
**kwargs: Any,
) -> ContextQAEvalChain:
"""Load QA Eval Chain from LLM.
Args:
llm (BaseLLM): the base language model to use.
llm (BaseLanguageModel): the base language model to use.
prompt (PromptTemplate): A prompt template containing the input_variables:
'query', 'context' and 'result' that will be used as the prompt
@ -120,7 +123,7 @@ class CotQAEvalChain(ContextQAEvalChain):
@classmethod
def from_llm(
cls, llm: BaseLLM, prompt: PromptTemplate = COT_PROMPT, **kwargs: Any
cls, llm: BaseLanguageModel, prompt: PromptTemplate = COT_PROMPT, **kwargs: Any
) -> CotQAEvalChain:
cls._validate_input_vars(prompt)
return cls(llm=llm, prompt=prompt, **kwargs)

@ -3,15 +3,15 @@ from __future__ import annotations
from typing import Any
from langchain.base_language import BaseLanguageModel
from langchain.chains.llm import LLMChain
from langchain.evaluation.qa.generate_prompt import PROMPT
from langchain.llms.base import BaseLLM
class QAGenerateChain(LLMChain):
"""LLM Chain specifically for generating examples for question answering."""
@classmethod
def from_llm(cls, llm: BaseLLM, **kwargs: Any) -> QAGenerateChain:
def from_llm(cls, llm: BaseLanguageModel, **kwargs: Any) -> QAGenerateChain:
"""Load QA Generate Chain from LLM."""
return cls(llm=llm, prompt=PROMPT, **kwargs)

@ -1,8 +1,8 @@
"""Utility functions for working with prompts."""
from typing import List
from langchain.base_language import BaseLanguageModel
from langchain.chains.llm import LLMChain
from langchain.llms.base import BaseLLM
from langchain.prompts.few_shot import FewShotPromptTemplate
from langchain.prompts.prompt import PromptTemplate
@ -10,7 +10,7 @@ TEST_GEN_TEMPLATE_SUFFIX = "Add another example."
def generate_example(
examples: List[dict], llm: BaseLLM, prompt_template: PromptTemplate
examples: List[dict], llm: BaseLanguageModel, prompt_template: PromptTemplate
) -> str:
"""Return another example given a list of examples for a prompt."""
prompt = FewShotPromptTemplate(

@ -3,18 +3,18 @@ from typing import Optional, Type
from pydantic import BaseModel
from langchain.base_language import BaseLanguageModel
from langchain.chains.llm import LLMChain
from langchain.graphs.networkx_graph import NetworkxEntityGraph, parse_triples
from langchain.indexes.prompts.knowledge_triplet_extraction import (
KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT,
)
from langchain.llms.base import BaseLLM
class GraphIndexCreator(BaseModel):
"""Functionality to create graph index."""
llm: Optional[BaseLLM] = None
llm: Optional[BaseLanguageModel] = None
graph_type: Type[NetworkxEntityGraph] = NetworkxEntityGraph
def from_text(self, text: str) -> NetworkxEntityGraph:

@ -2,12 +2,12 @@ from typing import Any, List, Optional, Type
from pydantic import BaseModel, Extra, Field
from langchain.base_language import BaseLanguageModel
from langchain.chains.qa_with_sources.retrieval import RetrievalQAWithSourcesChain
from langchain.chains.retrieval_qa.base import RetrievalQA
from langchain.document_loaders.base import BaseLoader
from langchain.embeddings.base import Embeddings
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.llms.base import BaseLLM
from langchain.llms.openai import OpenAI
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
@ -30,7 +30,9 @@ class VectorStoreIndexWrapper(BaseModel):
extra = Extra.forbid
arbitrary_types_allowed = True
def query(self, question: str, llm: Optional[BaseLLM] = None, **kwargs: Any) -> str:
def query(
self, question: str, llm: Optional[BaseLanguageModel] = None, **kwargs: Any
) -> str:
"""Query the vectorstore."""
llm = llm or OpenAI(temperature=0)
chain = RetrievalQA.from_chain_type(
@ -39,7 +41,7 @@ class VectorStoreIndexWrapper(BaseModel):
return chain.run(question)
def query_with_sources(
self, question: str, llm: Optional[BaseLLM] = None, **kwargs: Any
self, question: str, llm: Optional[BaseLanguageModel] = None, **kwargs: Any
) -> dict:
"""Query the vectorstore and get back sources."""
llm = llm or OpenAI(temperature=0)

Loading…
Cancel
Save