diff --git a/langchain/chains/api/base.py b/langchain/chains/api/base.py index 3bb5f917..0dcbe853 100644 --- a/langchain/chains/api/base.py +++ b/langchain/chains/api/base.py @@ -3,7 +3,7 @@ from __future__ import annotations from typing import Any, Dict, List, Optional -from pydantic import BaseModel, root_validator +from pydantic import BaseModel, Field, root_validator from langchain.chains.api.prompt import API_RESPONSE_PROMPT, API_URL_PROMPT from langchain.chains.base import Chain @@ -18,7 +18,7 @@ class APIChain(Chain, BaseModel): api_request_chain: LLMChain api_answer_chain: LLMChain - requests_wrapper: RequestsWrapper + requests_wrapper: RequestsWrapper = Field(exclude=True) api_docs: str question_key: str = "question" #: :meta private: output_key: str = "output" #: :meta private: @@ -102,3 +102,7 @@ class APIChain(Chain, BaseModel): api_docs=api_docs, **kwargs, ) + + @property + def _chain_type(self) -> str: + return "api_chain" diff --git a/langchain/chains/combine_documents/map_reduce.py b/langchain/chains/combine_documents/map_reduce.py index 0addc9dc..1e4fade7 100644 --- a/langchain/chains/combine_documents/map_reduce.py +++ b/langchain/chains/combine_documents/map_reduce.py @@ -168,3 +168,7 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain, BaseModel): extra_return_dict = {} output, _ = self.combine_document_chain.combine_docs(result_docs, **kwargs) return output, extra_return_dict + + @property + def _chain_type(self) -> str: + return "map_reduce_documents_chain" diff --git a/langchain/chains/combine_documents/map_rerank.py b/langchain/chains/combine_documents/map_rerank.py index d1c8092c..f97fc032 100644 --- a/langchain/chains/combine_documents/map_rerank.py +++ b/langchain/chains/combine_documents/map_rerank.py @@ -111,3 +111,7 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain, BaseModel): if self.return_intermediate_steps: extra_info["intermediate_steps"] = results return output[self.answer_key], extra_info + + @property + def _chain_type(self) -> str: + return "map_rerank_documents_chain" diff --git a/langchain/chains/combine_documents/refine.py b/langchain/chains/combine_documents/refine.py index 61d714eb..57d3d025 100644 --- a/langchain/chains/combine_documents/refine.py +++ b/langchain/chains/combine_documents/refine.py @@ -113,3 +113,7 @@ class RefineDocumentsChain(BaseCombineDocumentsChain, BaseModel): else: extra_return_dict = {} return res, extra_return_dict + + @property + def _chain_type(self) -> str: + return "refine_documents_chain" diff --git a/langchain/chains/combine_documents/stuff.py b/langchain/chains/combine_documents/stuff.py index 67bdfa75..d5cfb993 100644 --- a/langchain/chains/combine_documents/stuff.py +++ b/langchain/chains/combine_documents/stuff.py @@ -83,3 +83,7 @@ class StuffDocumentsChain(BaseCombineDocumentsChain, BaseModel): inputs = self._get_inputs(docs, **kwargs) # Call predict on the LLM. return self.llm_chain.predict(**inputs), {} + + @property + def _chain_type(self) -> str: + return "stuff_documents_chain" diff --git a/langchain/chains/hyde/base.py b/langchain/chains/hyde/base.py index fd043bba..29ee31de 100644 --- a/langchain/chains/hyde/base.py +++ b/langchain/chains/hyde/base.py @@ -69,3 +69,7 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings, BaseModel): prompt = PROMPT_MAP[prompt_key] llm_chain = LLMChain(llm=llm, prompt=prompt) return cls(base_embeddings=base_embeddings, llm_chain=llm_chain) + + @property + def _chain_type(self) -> str: + return "hyde_chain" diff --git a/langchain/chains/llm_bash/base.py b/langchain/chains/llm_bash/base.py index 9cc657ea..5a0d88eb 100644 --- a/langchain/chains/llm_bash/base.py +++ b/langchain/chains/llm_bash/base.py @@ -73,3 +73,7 @@ class LLMBashChain(Chain, BaseModel): else: raise ValueError(f"unknown format from LLM: {t}") return {self.output_key: output} + + @property + def _chain_type(self) -> str: + return "llm_bash_chain" diff --git a/langchain/chains/llm_checker/base.py b/langchain/chains/llm_checker/base.py index e2f606e5..cd2f0eec 100644 --- a/langchain/chains/llm_checker/base.py +++ b/langchain/chains/llm_checker/base.py @@ -97,3 +97,7 @@ class LLMCheckerChain(Chain, BaseModel): ) output = question_to_checked_assertions_chain({"question": question}) return {self.output_key: output["revised_statement"]} + + @property + def _chain_type(self) -> str: + return "llm_checker_chain" diff --git a/langchain/chains/llm_math/base.py b/langchain/chains/llm_math/base.py index c169ade3..24115236 100644 --- a/langchain/chains/llm_math/base.py +++ b/langchain/chains/llm_math/base.py @@ -68,3 +68,7 @@ class LLMMathChain(Chain, BaseModel): else: raise ValueError(f"unknown format from LLM: {t}") return {self.output_key: answer} + + @property + def _chain_type(self) -> str: + return "llm_math_chain" diff --git a/langchain/chains/llm_requests.py b/langchain/chains/llm_requests.py index ed0efa21..f7b0dc4e 100644 --- a/langchain/chains/llm_requests.py +++ b/langchain/chains/llm_requests.py @@ -18,7 +18,9 @@ class LLMRequestsChain(Chain, BaseModel): """Chain that hits a URL and then uses an LLM to parse results.""" llm_chain: LLMChain - requests_wrapper: RequestsWrapper = Field(default_factory=RequestsWrapper) + requests_wrapper: RequestsWrapper = Field( + default_factory=RequestsWrapper, exclude=True + ) text_length: int = 8000 requests_key: str = "requests_result" #: :meta private: input_key: str = "url" #: :meta private: @@ -71,3 +73,7 @@ class LLMRequestsChain(Chain, BaseModel): other_keys[self.requests_key] = soup.get_text()[: self.text_length] result = self.llm_chain.predict(**other_keys) return {self.output_key: result} + + @property + def _chain_type(self) -> str: + return "llm_requests_chain" diff --git a/langchain/chains/loading.py b/langchain/chains/loading.py index 60a2e16b..10095d02 100644 --- a/langchain/chains/loading.py +++ b/langchain/chains/loading.py @@ -3,20 +3,35 @@ import json import os import tempfile from pathlib import Path -from typing import Union +from typing import Any, Union import requests import yaml +from langchain.chains.api.base import APIChain from langchain.chains.base import Chain +from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain +from langchain.chains.combine_documents.map_rerank import MapRerankDocumentsChain +from langchain.chains.combine_documents.refine import RefineDocumentsChain +from langchain.chains.combine_documents.stuff import StuffDocumentsChain +from langchain.chains.hyde.base import HypotheticalDocumentEmbedder from langchain.chains.llm import LLMChain +from langchain.chains.llm_bash.base import LLMBashChain +from langchain.chains.llm_checker.base import LLMCheckerChain +from langchain.chains.llm_math.base import LLMMathChain +from langchain.chains.llm_requests import LLMRequestsChain +from langchain.chains.pal.base import PALChain +from langchain.chains.qa_with_sources.base import QAWithSourcesChain +from langchain.chains.qa_with_sources.vector_db import VectorDBQAWithSourcesChain +from langchain.chains.sql_database.base import SQLDatabaseChain +from langchain.chains.vector_db_qa.base import VectorDBQA from langchain.llms.loading import load_llm, load_llm_from_config from langchain.prompts.loading import load_prompt, load_prompt_from_config URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/chains/" -def _load_llm_chain(config: dict) -> LLMChain: +def _load_llm_chain(config: dict, **kwargs: Any) -> LLMChain: """Load LLM chain from config dict.""" if "llm" in config: llm_config = config.pop("llm") @@ -24,7 +39,7 @@ def _load_llm_chain(config: dict) -> LLMChain: elif "llm_path" in config: llm = load_llm(config.pop("llm_path")) else: - raise ValueError("One of `llm` or `llm_config` must be present.") + raise ValueError("One of `llm` or `llm_path` must be present.") if "prompt" in config: prompt_config = config.pop("prompt") @@ -37,32 +52,403 @@ def _load_llm_chain(config: dict) -> LLMChain: return LLMChain(llm=llm, prompt=prompt, **config) -type_to_loader_dict = {"llm_chain": _load_llm_chain} +def _load_hyde_chain(config: dict, **kwargs: Any) -> HypotheticalDocumentEmbedder: + """Load hypothetical document embedder chain from config dict.""" + if "llm_chain" in config: + llm_chain_config = config.pop("llm_chain") + llm_chain = load_chain_from_config(llm_chain_config) + elif "llm_chain_path" in config: + llm_chain = load_chain(config.pop("llm_chain_path")) + else: + raise ValueError("One of `llm_chain` or `llm_chain_path` must be present.") + if "embeddings" in kwargs: + embeddings = kwargs.pop("embeddings") + else: + raise ValueError("`embeddings` must be present.") + return HypotheticalDocumentEmbedder( + llm_chain=llm_chain, base_embeddings=embeddings, **config + ) + + +def _load_stuff_documents_chain(config: dict, **kwargs: Any) -> StuffDocumentsChain: + if "llm_chain" in config: + llm_chain_config = config.pop("llm_chain") + llm_chain = load_chain_from_config(llm_chain_config) + elif "llm_chain_path" in config: + llm_chain = load_chain(config.pop("llm_chain_path")) + else: + raise ValueError("One of `llm_chain` or `llm_chain_config` must be present.") + + if not isinstance(llm_chain, LLMChain): + raise ValueError(f"Expected LLMChain, got {llm_chain}") + + if "document_prompt" in config: + prompt_config = config.pop("document_prompt") + document_prompt = load_prompt_from_config(prompt_config) + elif "document_prompt_path" in config: + document_prompt = load_prompt(config.pop("document_prompt_path")) + else: + raise ValueError( + "One of `document_prompt` or `document_prompt_path` must be present." + ) + + return StuffDocumentsChain( + llm_chain=llm_chain, document_prompt=document_prompt, **config + ) + + +def _load_map_reduce_documents_chain( + config: dict, **kwargs: Any +) -> MapReduceDocumentsChain: + if "llm_chain" in config: + llm_chain_config = config.pop("llm_chain") + llm_chain = load_chain_from_config(llm_chain_config) + elif "llm_chain_path" in config: + llm_chain = load_chain(config.pop("llm_chain_path")) + else: + raise ValueError("One of `llm_chain` or `llm_chain_config` must be present.") + + if not isinstance(llm_chain, LLMChain): + raise ValueError(f"Expected LLMChain, got {llm_chain}") + + if "combine_document_chain" in config: + combine_document_chain_config = config.pop("combine_document_chain") + combine_document_chain = load_chain_from_config(combine_document_chain_config) + elif "combine_document_chain_path" in config: + combine_document_chain = load_chain(config.pop("combine_document_chain_path")) + else: + raise ValueError( + "One of `combine_document_chain` or " + "`combine_document_chain_path` must be present." + ) + if "collapse_document_chain" in config: + collapse_document_chain_config = config.pop("collapse_document_chain") + if collapse_document_chain_config is None: + collapse_document_chain = None + else: + collapse_document_chain = load_chain_from_config( + collapse_document_chain_config + ) + elif "collapse_document_chain_path" in config: + collapse_document_chain = load_chain(config.pop("collapse_document_chain_path")) + return MapReduceDocumentsChain( + llm_chain=llm_chain, + combine_document_chain=combine_document_chain, + collapse_document_chain=collapse_document_chain, + **config, + ) + + +def _load_llm_bash_chain(config: dict, **kwargs: Any) -> LLMBashChain: + if "llm" in config: + llm_config = config.pop("llm") + llm = load_llm_from_config(llm_config) + elif "llm_path" in config: + llm = load_llm(config.pop("llm_path")) + else: + raise ValueError("One of `llm` or `llm_path` must be present.") + if "prompt" in config: + prompt_config = config.pop("prompt") + prompt = load_prompt_from_config(prompt_config) + elif "prompt_path" in config: + prompt = load_prompt(config.pop("prompt_path")) + return LLMBashChain(llm=llm, prompt=prompt, **config) + + +def _load_llm_checker_chain(config: dict, **kwargs: Any) -> LLMCheckerChain: + if "llm" in config: + llm_config = config.pop("llm") + llm = load_llm_from_config(llm_config) + elif "llm_path" in config: + llm = load_llm(config.pop("llm_path")) + else: + raise ValueError("One of `llm` or `llm_path` must be present.") + if "create_draft_answer_prompt" in config: + create_draft_answer_prompt_config = config.pop("create_draft_answer_prompt") + create_draft_answer_prompt = load_prompt_from_config( + create_draft_answer_prompt_config + ) + elif "create_draft_answer_prompt_path" in config: + create_draft_answer_prompt = load_prompt( + config.pop("create_draft_answer_prompt_path") + ) + if "list_assertions_prompt" in config: + list_assertions_prompt_config = config.pop("list_assertions_prompt") + list_assertions_prompt = load_prompt_from_config(list_assertions_prompt_config) + elif "list_assertions_prompt_path" in config: + list_assertions_prompt = load_prompt(config.pop("list_assertions_prompt_path")) + if "check_assertions_prompt" in config: + check_assertions_prompt_config = config.pop("check_assertions_prompt") + check_assertions_prompt = load_prompt_from_config( + check_assertions_prompt_config + ) + elif "check_assertions_prompt_path" in config: + check_assertions_prompt = load_prompt( + config.pop("check_assertions_prompt_path") + ) + if "revised_answer_prompt" in config: + revised_answer_prompt_config = config.pop("revised_answer_prompt") + revised_answer_prompt = load_prompt_from_config(revised_answer_prompt_config) + elif "revised_answer_prompt_path" in config: + revised_answer_prompt = load_prompt(config.pop("revised_answer_prompt_path")) + return LLMCheckerChain( + llm=llm, + create_draft_answer_prompt=create_draft_answer_prompt, + list_assertions_prompt=list_assertions_prompt, + check_assertions_prompt=check_assertions_prompt, + revised_answer_prompt=revised_answer_prompt, + **config, + ) + + +def _load_llm_math_chain(config: dict, **kwargs: Any) -> LLMMathChain: + if "llm" in config: + llm_config = config.pop("llm") + llm = load_llm_from_config(llm_config) + elif "llm_path" in config: + llm = load_llm(config.pop("llm_path")) + else: + raise ValueError("One of `llm` or `llm_path` must be present.") + if "prompt" in config: + prompt_config = config.pop("prompt") + prompt = load_prompt_from_config(prompt_config) + elif "prompt_path" in config: + prompt = load_prompt(config.pop("prompt_path")) + return LLMMathChain(llm=llm, prompt=prompt, **config) + + +def _load_map_rerank_documents_chain( + config: dict, **kwargs: Any +) -> MapRerankDocumentsChain: + if "llm_chain" in config: + llm_chain_config = config.pop("llm_chain") + llm_chain = load_chain_from_config(llm_chain_config) + elif "llm_chain_path" in config: + llm_chain = load_chain(config.pop("llm_chain_path")) + else: + raise ValueError("One of `llm_chain` or `llm_chain_config` must be present.") + return MapRerankDocumentsChain(llm_chain=llm_chain, **config) + + +def _load_pal_chain(config: dict, **kwargs: Any) -> PALChain: + if "llm" in config: + llm_config = config.pop("llm") + llm = load_llm_from_config(llm_config) + elif "llm_path" in config: + llm = load_llm(config.pop("llm_path")) + else: + raise ValueError("One of `llm` or `llm_path` must be present.") + if "prompt" in config: + prompt_config = config.pop("prompt") + prompt = load_prompt_from_config(prompt_config) + elif "prompt_path" in config: + prompt = load_prompt(config.pop("prompt_path")) + else: + raise ValueError("One of `prompt` or `prompt_path` must be present.") + return PALChain(llm=llm, prompt=prompt, **config) + + +def _load_refine_documents_chain(config: dict, **kwargs: Any) -> RefineDocumentsChain: + if "initial_llm_chain" in config: + initial_llm_chain_config = config.pop("initial_llm_chain") + initial_llm_chain = load_chain_from_config(initial_llm_chain_config) + elif "initial_llm_chain_path" in config: + initial_llm_chain = load_chain(config.pop("initial_llm_chain_path")) + else: + raise ValueError( + "One of `initial_llm_chain` or `initial_llm_chain_config` must be present." + ) + if "refine_llm_chain" in config: + refine_llm_chain_config = config.pop("refine_llm_chain") + refine_llm_chain = load_chain_from_config(refine_llm_chain_config) + elif "refine_llm_chain_path" in config: + refine_llm_chain = load_chain(config.pop("refine_llm_chain_path")) + else: + raise ValueError( + "One of `refine_llm_chain` or `refine_llm_chain_config` must be present." + ) + if "document_prompt" in config: + prompt_config = config.pop("document_prompt") + document_prompt = load_prompt_from_config(prompt_config) + elif "document_prompt_path" in config: + document_prompt = load_prompt(config.pop("document_prompt_path")) + return RefineDocumentsChain( + initial_llm_chain=initial_llm_chain, + refine_llm_chain=refine_llm_chain, + document_prompt=document_prompt, + **config, + ) + + +def _load_qa_with_sources_chain(config: dict, **kwargs: Any) -> QAWithSourcesChain: + if "combine_documents_chain" in config: + combine_documents_chain_config = config.pop("combine_documents_chain") + combine_documents_chain = load_chain_from_config(combine_documents_chain_config) + elif "combine_documents_chain_path" in config: + combine_documents_chain = load_chain(config.pop("combine_documents_chain_path")) + else: + raise ValueError( + "One of `combine_documents_chain` or " + "`combine_documents_chain_path` must be present." + ) + return QAWithSourcesChain(combine_documents_chain=combine_documents_chain, **config) + + +def _load_sql_database_chain(config: dict, **kwargs: Any) -> SQLDatabaseChain: + if "database" in kwargs: + database = kwargs.pop("database") + else: + raise ValueError("`database` must be present.") + if "llm" in config: + llm_config = config.pop("llm") + llm = load_llm_from_config(llm_config) + elif "llm_path" in config: + llm = load_llm(config.pop("llm_path")) + else: + raise ValueError("One of `llm` or `llm_path` must be present.") + if "prompt" in config: + prompt_config = config.pop("prompt") + prompt = load_prompt_from_config(prompt_config) + return SQLDatabaseChain(database=database, llm=llm, prompt=prompt, **config) + + +def _load_vector_db_qa_with_sources_chain( + config: dict, **kwargs: Any +) -> VectorDBQAWithSourcesChain: + if "vectorstore" in kwargs: + vectorstore = kwargs.pop("vectorstore") + else: + raise ValueError("`vectorstore` must be present.") + if "combine_documents_chain" in config: + combine_documents_chain_config = config.pop("combine_documents_chain") + combine_documents_chain = load_chain_from_config(combine_documents_chain_config) + elif "combine_documents_chain_path" in config: + combine_documents_chain = load_chain(config.pop("combine_documents_chain_path")) + else: + raise ValueError( + "One of `combine_documents_chain` or " + "`combine_documents_chain_path` must be present." + ) + return VectorDBQAWithSourcesChain( + combine_documents_chain=combine_documents_chain, + vectorstore=vectorstore, + **config, + ) + + +def _load_vector_db_qa(config: dict, **kwargs: Any) -> VectorDBQA: + if "vectorstore" in kwargs: + vectorstore = kwargs.pop("vectorstore") + else: + raise ValueError("`vectorstore` must be present.") + if "combine_documents_chain" in config: + combine_documents_chain_config = config.pop("combine_documents_chain") + combine_documents_chain = load_chain_from_config(combine_documents_chain_config) + elif "combine_documents_chain_path" in config: + combine_documents_chain = load_chain(config.pop("combine_documents_chain_path")) + else: + raise ValueError( + "One of `combine_documents_chain` or " + "`combine_documents_chain_path` must be present." + ) + return VectorDBQA( + combine_documents_chain=combine_documents_chain, + vectorstore=vectorstore, + **config, + ) + + +def _load_api_chain(config: dict, **kwargs: Any) -> APIChain: + if "api_request_chain" in config: + api_request_chain_config = config.pop("api_request_chain") + api_request_chain = load_chain_from_config(api_request_chain_config) + elif "api_request_chain_path" in config: + api_request_chain = load_chain(config.pop("api_request_chain_path")) + else: + raise ValueError( + "One of `api_request_chain` or `api_request_chain_path` must be present." + ) + if "api_answer_chain" in config: + api_answer_chain_config = config.pop("api_answer_chain") + api_answer_chain = load_chain_from_config(api_answer_chain_config) + elif "api_answer_chain_path" in config: + api_answer_chain = load_chain(config.pop("api_answer_chain_path")) + else: + raise ValueError( + "One of `api_answer_chain` or `api_answer_chain_path` must be present." + ) + if "requests_wrapper" in kwargs: + requests_wrapper = kwargs.pop("requests_wrapper") + else: + raise ValueError("`requests_wrapper` must be present.") + return APIChain( + api_request_chain=api_request_chain, + api_answer_chain=api_answer_chain, + requests_wrapper=requests_wrapper, + **config, + ) + + +def _load_llm_requests_chain(config: dict, **kwargs: Any) -> LLMRequestsChain: + if "llm_chain" in config: + llm_chain_config = config.pop("llm_chain") + llm_chain = load_chain_from_config(llm_chain_config) + elif "llm_chain_path" in config: + llm_chain = load_chain(config.pop("llm_chain_path")) + else: + raise ValueError("One of `llm_chain` or `llm_chain_path` must be present.") + if "requests_wrapper" in kwargs: + requests_wrapper = kwargs.pop("requests_wrapper") + return LLMRequestsChain( + llm_chain=llm_chain, requests_wrapper=requests_wrapper, **config + ) + else: + return LLMRequestsChain(llm_chain=llm_chain, **config) + + +type_to_loader_dict = { + "api_chain": _load_api_chain, + "hyde_chain": _load_hyde_chain, + "llm_chain": _load_llm_chain, + "llm_bash_chain": _load_llm_bash_chain, + "llm_checker_chain": _load_llm_checker_chain, + "llm_math_chain": _load_llm_math_chain, + "llm_requests_chain": _load_llm_requests_chain, + "pal_chain": _load_pal_chain, + "qa_with_sources_chain": _load_qa_with_sources_chain, + "stuff_documents_chain": _load_stuff_documents_chain, + "map_reduce_documents_chain": _load_map_reduce_documents_chain, + "map_rerank_documents_chain": _load_map_rerank_documents_chain, + "refine_documents_chain": _load_refine_documents_chain, + "sql_database_chain": _load_sql_database_chain, + "vector_db_qa_with_sources_chain": _load_vector_db_qa_with_sources_chain, + "vector_db_qa": _load_vector_db_qa, +} -def load_chain_from_config(config: dict) -> Chain: +def load_chain_from_config(config: dict, **kwargs: Any) -> Chain: """Load chain from Config Dict.""" if "_type" not in config: - raise ValueError("Must specify an chain Type in config") + raise ValueError("Must specify a chain Type in config") config_type = config.pop("_type") if config_type not in type_to_loader_dict: raise ValueError(f"Loading {config_type} chain not supported") chain_loader = type_to_loader_dict[config_type] - return chain_loader(config) + return chain_loader(config, **kwargs) -def load_chain(path: Union[str, Path]) -> Chain: +def load_chain(path: Union[str, Path], **kwargs: Any) -> Chain: """Unified method for loading a chain from LangChainHub or local fs.""" if isinstance(path, str) and path.startswith("lc://chains"): path = os.path.relpath(path, "lc://chains/") - return _load_from_hub(path) + return _load_from_hub(path, **kwargs) else: - return _load_chain_from_file(path) + return _load_chain_from_file(path, **kwargs) -def _load_chain_from_file(file: Union[str, Path]) -> Chain: +def _load_chain_from_file(file: Union[str, Path], **kwargs: Any) -> Chain: """Load chain from file.""" # Convert file to Path object. if isinstance(file, str): @@ -79,10 +465,10 @@ def _load_chain_from_file(file: Union[str, Path]) -> Chain: else: raise ValueError("File type must be json or yaml") # Load the chain from the config now. - return load_chain_from_config(config) + return load_chain_from_config(config, **kwargs) -def _load_from_hub(path: str) -> Chain: +def _load_from_hub(path: str, **kwargs: Any) -> Chain: """Load chain from hub.""" suffix = path.split(".")[-1] if suffix not in {"json", "yaml"}: @@ -95,4 +481,4 @@ def _load_from_hub(path: str) -> Chain: file = tmpdirname + "/chain." + suffix with open(file, "wb") as f: f.write(r.content) - return _load_chain_from_file(file) + return _load_chain_from_file(file, **kwargs) diff --git a/langchain/chains/natbot/base.py b/langchain/chains/natbot/base.py index 6e75946f..d688c6b7 100644 --- a/langchain/chains/natbot/base.py +++ b/langchain/chains/natbot/base.py @@ -94,3 +94,7 @@ class NatBotChain(Chain, BaseModel): self.input_browser_content_key: browser_content, } return self(_inputs)[self.output_key] + + @property + def _chain_type(self) -> str: + return "nat_bot_chain" diff --git a/langchain/chains/pal/base.py b/langchain/chains/pal/base.py index 3b16ed86..ccecd9ec 100644 --- a/langchain/chains/pal/base.py +++ b/langchain/chains/pal/base.py @@ -79,3 +79,7 @@ class PALChain(Chain, BaseModel): get_answer_expr="print(answer)", **kwargs, ) + + @property + def _chain_type(self) -> str: + return "pal_chain" diff --git a/langchain/chains/qa_with_sources/base.py b/langchain/chains/qa_with_sources/base.py index 628db71b..d566e812 100644 --- a/langchain/chains/qa_with_sources/base.py +++ b/langchain/chains/qa_with_sources/base.py @@ -126,3 +126,7 @@ class QAWithSourcesChain(BaseQAWithSourcesChain, BaseModel): def _get_docs(self, inputs: Dict[str, Any]) -> List[Document]: return inputs.pop(self.input_docs_key) + + @property + def _chain_type(self) -> str: + return "qa_with_sources_chain" diff --git a/langchain/chains/qa_with_sources/vector_db.py b/langchain/chains/qa_with_sources/vector_db.py index 6da7ff33..8a567dfa 100644 --- a/langchain/chains/qa_with_sources/vector_db.py +++ b/langchain/chains/qa_with_sources/vector_db.py @@ -13,7 +13,7 @@ from langchain.vectorstores.base import VectorStore class VectorDBQAWithSourcesChain(BaseQAWithSourcesChain, BaseModel): """Question-answering with sources over a vector database.""" - vectorstore: VectorStore + vectorstore: VectorStore = Field(exclude=True) """Vector Database to connect to.""" k: int = 4 """Number of results to return from store""" @@ -50,3 +50,7 @@ class VectorDBQAWithSourcesChain(BaseQAWithSourcesChain, BaseModel): question, k=self.k, **self.search_kwargs ) return self._reduce_tokens_below_limit(docs) + + @property + def _chain_type(self) -> str: + return "vector_db_qa_with_sources_chain" diff --git a/langchain/chains/sql_database/base.py b/langchain/chains/sql_database/base.py index 10377800..dd9bbe84 100644 --- a/langchain/chains/sql_database/base.py +++ b/langchain/chains/sql_database/base.py @@ -3,7 +3,7 @@ from __future__ import annotations from typing import Any, Dict, List -from pydantic import BaseModel, Extra +from pydantic import BaseModel, Extra, Field from langchain.chains.base import Chain from langchain.chains.llm import LLMChain @@ -26,7 +26,7 @@ class SQLDatabaseChain(Chain, BaseModel): llm: BaseLLM """LLM wrapper to use.""" - database: SQLDatabase + database: SQLDatabase = Field(exclude=True) """SQL Database to connect to.""" prompt: BasePromptTemplate = PROMPT """Prompt to use to translate natural language to SQL.""" @@ -84,6 +84,10 @@ class SQLDatabaseChain(Chain, BaseModel): self.callback_manager.on_text(final_result, color="green", verbose=self.verbose) return {self.output_key: final_result} + @property + def _chain_type(self) -> str: + return "sql_database_chain" + class SQLDatabaseSequentialChain(Chain, BaseModel): """Chain for querying SQL database that is a sequential chain. @@ -153,3 +157,7 @@ class SQLDatabaseSequentialChain(Chain, BaseModel): "table_names_to_use": table_names_to_use, } return self.sql_chain(new_inputs, return_only_outputs=True) + + @property + def _chain_type(self) -> str: + return "sql_database_sequential_chain" diff --git a/langchain/chains/vector_db_qa/base.py b/langchain/chains/vector_db_qa/base.py index 69375d56..73bd4c57 100644 --- a/langchain/chains/vector_db_qa/base.py +++ b/langchain/chains/vector_db_qa/base.py @@ -29,7 +29,7 @@ class VectorDBQA(Chain, BaseModel): """ - vectorstore: VectorStore + vectorstore: VectorStore = Field(exclude=True) """Vector Database to connect to.""" k: int = 4 """Number of documents to query for.""" @@ -138,3 +138,8 @@ class VectorDBQA(Chain, BaseModel): return {self.output_key: answer, "source_documents": docs} else: return {self.output_key: answer} + + @property + def _chain_type(self) -> str: + """Return the chain type.""" + return "vector_db_qa"