Serialize all the chains! (#761)

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
ankush/async-llmchain
Samantha Whitmore 1 year ago committed by GitHub
parent e2a7fed890
commit be7de427ca
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -3,7 +3,7 @@ from __future__ import annotations
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from pydantic import BaseModel, root_validator from pydantic import BaseModel, Field, root_validator
from langchain.chains.api.prompt import API_RESPONSE_PROMPT, API_URL_PROMPT from langchain.chains.api.prompt import API_RESPONSE_PROMPT, API_URL_PROMPT
from langchain.chains.base import Chain from langchain.chains.base import Chain
@ -18,7 +18,7 @@ class APIChain(Chain, BaseModel):
api_request_chain: LLMChain api_request_chain: LLMChain
api_answer_chain: LLMChain api_answer_chain: LLMChain
requests_wrapper: RequestsWrapper requests_wrapper: RequestsWrapper = Field(exclude=True)
api_docs: str api_docs: str
question_key: str = "question" #: :meta private: question_key: str = "question" #: :meta private:
output_key: str = "output" #: :meta private: output_key: str = "output" #: :meta private:
@ -102,3 +102,7 @@ class APIChain(Chain, BaseModel):
api_docs=api_docs, api_docs=api_docs,
**kwargs, **kwargs,
) )
@property
def _chain_type(self) -> str:
return "api_chain"

@ -168,3 +168,7 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain, BaseModel):
extra_return_dict = {} extra_return_dict = {}
output, _ = self.combine_document_chain.combine_docs(result_docs, **kwargs) output, _ = self.combine_document_chain.combine_docs(result_docs, **kwargs)
return output, extra_return_dict return output, extra_return_dict
@property
def _chain_type(self) -> str:
return "map_reduce_documents_chain"

@ -111,3 +111,7 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain, BaseModel):
if self.return_intermediate_steps: if self.return_intermediate_steps:
extra_info["intermediate_steps"] = results extra_info["intermediate_steps"] = results
return output[self.answer_key], extra_info return output[self.answer_key], extra_info
@property
def _chain_type(self) -> str:
return "map_rerank_documents_chain"

@ -113,3 +113,7 @@ class RefineDocumentsChain(BaseCombineDocumentsChain, BaseModel):
else: else:
extra_return_dict = {} extra_return_dict = {}
return res, extra_return_dict return res, extra_return_dict
@property
def _chain_type(self) -> str:
return "refine_documents_chain"

@ -83,3 +83,7 @@ class StuffDocumentsChain(BaseCombineDocumentsChain, BaseModel):
inputs = self._get_inputs(docs, **kwargs) inputs = self._get_inputs(docs, **kwargs)
# Call predict on the LLM. # Call predict on the LLM.
return self.llm_chain.predict(**inputs), {} return self.llm_chain.predict(**inputs), {}
@property
def _chain_type(self) -> str:
return "stuff_documents_chain"

@ -69,3 +69,7 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings, BaseModel):
prompt = PROMPT_MAP[prompt_key] prompt = PROMPT_MAP[prompt_key]
llm_chain = LLMChain(llm=llm, prompt=prompt) llm_chain = LLMChain(llm=llm, prompt=prompt)
return cls(base_embeddings=base_embeddings, llm_chain=llm_chain) return cls(base_embeddings=base_embeddings, llm_chain=llm_chain)
@property
def _chain_type(self) -> str:
return "hyde_chain"

@ -73,3 +73,7 @@ class LLMBashChain(Chain, BaseModel):
else: else:
raise ValueError(f"unknown format from LLM: {t}") raise ValueError(f"unknown format from LLM: {t}")
return {self.output_key: output} return {self.output_key: output}
@property
def _chain_type(self) -> str:
return "llm_bash_chain"

@ -97,3 +97,7 @@ class LLMCheckerChain(Chain, BaseModel):
) )
output = question_to_checked_assertions_chain({"question": question}) output = question_to_checked_assertions_chain({"question": question})
return {self.output_key: output["revised_statement"]} return {self.output_key: output["revised_statement"]}
@property
def _chain_type(self) -> str:
return "llm_checker_chain"

@ -68,3 +68,7 @@ class LLMMathChain(Chain, BaseModel):
else: else:
raise ValueError(f"unknown format from LLM: {t}") raise ValueError(f"unknown format from LLM: {t}")
return {self.output_key: answer} return {self.output_key: answer}
@property
def _chain_type(self) -> str:
return "llm_math_chain"

@ -18,7 +18,9 @@ class LLMRequestsChain(Chain, BaseModel):
"""Chain that hits a URL and then uses an LLM to parse results.""" """Chain that hits a URL and then uses an LLM to parse results."""
llm_chain: LLMChain llm_chain: LLMChain
requests_wrapper: RequestsWrapper = Field(default_factory=RequestsWrapper) requests_wrapper: RequestsWrapper = Field(
default_factory=RequestsWrapper, exclude=True
)
text_length: int = 8000 text_length: int = 8000
requests_key: str = "requests_result" #: :meta private: requests_key: str = "requests_result" #: :meta private:
input_key: str = "url" #: :meta private: input_key: str = "url" #: :meta private:
@ -71,3 +73,7 @@ class LLMRequestsChain(Chain, BaseModel):
other_keys[self.requests_key] = soup.get_text()[: self.text_length] other_keys[self.requests_key] = soup.get_text()[: self.text_length]
result = self.llm_chain.predict(**other_keys) result = self.llm_chain.predict(**other_keys)
return {self.output_key: result} return {self.output_key: result}
@property
def _chain_type(self) -> str:
return "llm_requests_chain"

@ -3,20 +3,35 @@ import json
import os import os
import tempfile import tempfile
from pathlib import Path from pathlib import Path
from typing import Union from typing import Any, Union
import requests import requests
import yaml import yaml
from langchain.chains.api.base import APIChain
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
from langchain.chains.combine_documents.map_rerank import MapRerankDocumentsChain
from langchain.chains.combine_documents.refine import RefineDocumentsChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.hyde.base import HypotheticalDocumentEmbedder
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.chains.llm_bash.base import LLMBashChain
from langchain.chains.llm_checker.base import LLMCheckerChain
from langchain.chains.llm_math.base import LLMMathChain
from langchain.chains.llm_requests import LLMRequestsChain
from langchain.chains.pal.base import PALChain
from langchain.chains.qa_with_sources.base import QAWithSourcesChain
from langchain.chains.qa_with_sources.vector_db import VectorDBQAWithSourcesChain
from langchain.chains.sql_database.base import SQLDatabaseChain
from langchain.chains.vector_db_qa.base import VectorDBQA
from langchain.llms.loading import load_llm, load_llm_from_config from langchain.llms.loading import load_llm, load_llm_from_config
from langchain.prompts.loading import load_prompt, load_prompt_from_config from langchain.prompts.loading import load_prompt, load_prompt_from_config
URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/chains/" URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/chains/"
def _load_llm_chain(config: dict) -> LLMChain: def _load_llm_chain(config: dict, **kwargs: Any) -> LLMChain:
"""Load LLM chain from config dict.""" """Load LLM chain from config dict."""
if "llm" in config: if "llm" in config:
llm_config = config.pop("llm") llm_config = config.pop("llm")
@ -24,7 +39,7 @@ def _load_llm_chain(config: dict) -> LLMChain:
elif "llm_path" in config: elif "llm_path" in config:
llm = load_llm(config.pop("llm_path")) llm = load_llm(config.pop("llm_path"))
else: else:
raise ValueError("One of `llm` or `llm_config` must be present.") raise ValueError("One of `llm` or `llm_path` must be present.")
if "prompt" in config: if "prompt" in config:
prompt_config = config.pop("prompt") prompt_config = config.pop("prompt")
@ -37,32 +52,403 @@ def _load_llm_chain(config: dict) -> LLMChain:
return LLMChain(llm=llm, prompt=prompt, **config) return LLMChain(llm=llm, prompt=prompt, **config)
type_to_loader_dict = {"llm_chain": _load_llm_chain} def _load_hyde_chain(config: dict, **kwargs: Any) -> HypotheticalDocumentEmbedder:
"""Load hypothetical document embedder chain from config dict."""
if "llm_chain" in config:
llm_chain_config = config.pop("llm_chain")
llm_chain = load_chain_from_config(llm_chain_config)
elif "llm_chain_path" in config:
llm_chain = load_chain(config.pop("llm_chain_path"))
else:
raise ValueError("One of `llm_chain` or `llm_chain_path` must be present.")
if "embeddings" in kwargs:
embeddings = kwargs.pop("embeddings")
else:
raise ValueError("`embeddings` must be present.")
return HypotheticalDocumentEmbedder(
llm_chain=llm_chain, base_embeddings=embeddings, **config
)
def _load_stuff_documents_chain(config: dict, **kwargs: Any) -> StuffDocumentsChain:
if "llm_chain" in config:
llm_chain_config = config.pop("llm_chain")
llm_chain = load_chain_from_config(llm_chain_config)
elif "llm_chain_path" in config:
llm_chain = load_chain(config.pop("llm_chain_path"))
else:
raise ValueError("One of `llm_chain` or `llm_chain_config` must be present.")
if not isinstance(llm_chain, LLMChain):
raise ValueError(f"Expected LLMChain, got {llm_chain}")
if "document_prompt" in config:
prompt_config = config.pop("document_prompt")
document_prompt = load_prompt_from_config(prompt_config)
elif "document_prompt_path" in config:
document_prompt = load_prompt(config.pop("document_prompt_path"))
else:
raise ValueError(
"One of `document_prompt` or `document_prompt_path` must be present."
)
return StuffDocumentsChain(
llm_chain=llm_chain, document_prompt=document_prompt, **config
)
def _load_map_reduce_documents_chain(
config: dict, **kwargs: Any
) -> MapReduceDocumentsChain:
if "llm_chain" in config:
llm_chain_config = config.pop("llm_chain")
llm_chain = load_chain_from_config(llm_chain_config)
elif "llm_chain_path" in config:
llm_chain = load_chain(config.pop("llm_chain_path"))
else:
raise ValueError("One of `llm_chain` or `llm_chain_config` must be present.")
if not isinstance(llm_chain, LLMChain):
raise ValueError(f"Expected LLMChain, got {llm_chain}")
if "combine_document_chain" in config:
combine_document_chain_config = config.pop("combine_document_chain")
combine_document_chain = load_chain_from_config(combine_document_chain_config)
elif "combine_document_chain_path" in config:
combine_document_chain = load_chain(config.pop("combine_document_chain_path"))
else:
raise ValueError(
"One of `combine_document_chain` or "
"`combine_document_chain_path` must be present."
)
if "collapse_document_chain" in config:
collapse_document_chain_config = config.pop("collapse_document_chain")
if collapse_document_chain_config is None:
collapse_document_chain = None
else:
collapse_document_chain = load_chain_from_config(
collapse_document_chain_config
)
elif "collapse_document_chain_path" in config:
collapse_document_chain = load_chain(config.pop("collapse_document_chain_path"))
return MapReduceDocumentsChain(
llm_chain=llm_chain,
combine_document_chain=combine_document_chain,
collapse_document_chain=collapse_document_chain,
**config,
)
def _load_llm_bash_chain(config: dict, **kwargs: Any) -> LLMBashChain:
if "llm" in config:
llm_config = config.pop("llm")
llm = load_llm_from_config(llm_config)
elif "llm_path" in config:
llm = load_llm(config.pop("llm_path"))
else:
raise ValueError("One of `llm` or `llm_path` must be present.")
if "prompt" in config:
prompt_config = config.pop("prompt")
prompt = load_prompt_from_config(prompt_config)
elif "prompt_path" in config:
prompt = load_prompt(config.pop("prompt_path"))
return LLMBashChain(llm=llm, prompt=prompt, **config)
def _load_llm_checker_chain(config: dict, **kwargs: Any) -> LLMCheckerChain:
if "llm" in config:
llm_config = config.pop("llm")
llm = load_llm_from_config(llm_config)
elif "llm_path" in config:
llm = load_llm(config.pop("llm_path"))
else:
raise ValueError("One of `llm` or `llm_path` must be present.")
if "create_draft_answer_prompt" in config:
create_draft_answer_prompt_config = config.pop("create_draft_answer_prompt")
create_draft_answer_prompt = load_prompt_from_config(
create_draft_answer_prompt_config
)
elif "create_draft_answer_prompt_path" in config:
create_draft_answer_prompt = load_prompt(
config.pop("create_draft_answer_prompt_path")
)
if "list_assertions_prompt" in config:
list_assertions_prompt_config = config.pop("list_assertions_prompt")
list_assertions_prompt = load_prompt_from_config(list_assertions_prompt_config)
elif "list_assertions_prompt_path" in config:
list_assertions_prompt = load_prompt(config.pop("list_assertions_prompt_path"))
if "check_assertions_prompt" in config:
check_assertions_prompt_config = config.pop("check_assertions_prompt")
check_assertions_prompt = load_prompt_from_config(
check_assertions_prompt_config
)
elif "check_assertions_prompt_path" in config:
check_assertions_prompt = load_prompt(
config.pop("check_assertions_prompt_path")
)
if "revised_answer_prompt" in config:
revised_answer_prompt_config = config.pop("revised_answer_prompt")
revised_answer_prompt = load_prompt_from_config(revised_answer_prompt_config)
elif "revised_answer_prompt_path" in config:
revised_answer_prompt = load_prompt(config.pop("revised_answer_prompt_path"))
return LLMCheckerChain(
llm=llm,
create_draft_answer_prompt=create_draft_answer_prompt,
list_assertions_prompt=list_assertions_prompt,
check_assertions_prompt=check_assertions_prompt,
revised_answer_prompt=revised_answer_prompt,
**config,
)
def _load_llm_math_chain(config: dict, **kwargs: Any) -> LLMMathChain:
if "llm" in config:
llm_config = config.pop("llm")
llm = load_llm_from_config(llm_config)
elif "llm_path" in config:
llm = load_llm(config.pop("llm_path"))
else:
raise ValueError("One of `llm` or `llm_path` must be present.")
if "prompt" in config:
prompt_config = config.pop("prompt")
prompt = load_prompt_from_config(prompt_config)
elif "prompt_path" in config:
prompt = load_prompt(config.pop("prompt_path"))
return LLMMathChain(llm=llm, prompt=prompt, **config)
def _load_map_rerank_documents_chain(
config: dict, **kwargs: Any
) -> MapRerankDocumentsChain:
if "llm_chain" in config:
llm_chain_config = config.pop("llm_chain")
llm_chain = load_chain_from_config(llm_chain_config)
elif "llm_chain_path" in config:
llm_chain = load_chain(config.pop("llm_chain_path"))
else:
raise ValueError("One of `llm_chain` or `llm_chain_config` must be present.")
return MapRerankDocumentsChain(llm_chain=llm_chain, **config)
def _load_pal_chain(config: dict, **kwargs: Any) -> PALChain:
if "llm" in config:
llm_config = config.pop("llm")
llm = load_llm_from_config(llm_config)
elif "llm_path" in config:
llm = load_llm(config.pop("llm_path"))
else:
raise ValueError("One of `llm` or `llm_path` must be present.")
if "prompt" in config:
prompt_config = config.pop("prompt")
prompt = load_prompt_from_config(prompt_config)
elif "prompt_path" in config:
prompt = load_prompt(config.pop("prompt_path"))
else:
raise ValueError("One of `prompt` or `prompt_path` must be present.")
return PALChain(llm=llm, prompt=prompt, **config)
def _load_refine_documents_chain(config: dict, **kwargs: Any) -> RefineDocumentsChain:
if "initial_llm_chain" in config:
initial_llm_chain_config = config.pop("initial_llm_chain")
initial_llm_chain = load_chain_from_config(initial_llm_chain_config)
elif "initial_llm_chain_path" in config:
initial_llm_chain = load_chain(config.pop("initial_llm_chain_path"))
else:
raise ValueError(
"One of `initial_llm_chain` or `initial_llm_chain_config` must be present."
)
if "refine_llm_chain" in config:
refine_llm_chain_config = config.pop("refine_llm_chain")
refine_llm_chain = load_chain_from_config(refine_llm_chain_config)
elif "refine_llm_chain_path" in config:
refine_llm_chain = load_chain(config.pop("refine_llm_chain_path"))
else:
raise ValueError(
"One of `refine_llm_chain` or `refine_llm_chain_config` must be present."
)
if "document_prompt" in config:
prompt_config = config.pop("document_prompt")
document_prompt = load_prompt_from_config(prompt_config)
elif "document_prompt_path" in config:
document_prompt = load_prompt(config.pop("document_prompt_path"))
return RefineDocumentsChain(
initial_llm_chain=initial_llm_chain,
refine_llm_chain=refine_llm_chain,
document_prompt=document_prompt,
**config,
)
def _load_qa_with_sources_chain(config: dict, **kwargs: Any) -> QAWithSourcesChain:
if "combine_documents_chain" in config:
combine_documents_chain_config = config.pop("combine_documents_chain")
combine_documents_chain = load_chain_from_config(combine_documents_chain_config)
elif "combine_documents_chain_path" in config:
combine_documents_chain = load_chain(config.pop("combine_documents_chain_path"))
else:
raise ValueError(
"One of `combine_documents_chain` or "
"`combine_documents_chain_path` must be present."
)
return QAWithSourcesChain(combine_documents_chain=combine_documents_chain, **config)
def _load_sql_database_chain(config: dict, **kwargs: Any) -> SQLDatabaseChain:
if "database" in kwargs:
database = kwargs.pop("database")
else:
raise ValueError("`database` must be present.")
if "llm" in config:
llm_config = config.pop("llm")
llm = load_llm_from_config(llm_config)
elif "llm_path" in config:
llm = load_llm(config.pop("llm_path"))
else:
raise ValueError("One of `llm` or `llm_path` must be present.")
if "prompt" in config:
prompt_config = config.pop("prompt")
prompt = load_prompt_from_config(prompt_config)
return SQLDatabaseChain(database=database, llm=llm, prompt=prompt, **config)
def _load_vector_db_qa_with_sources_chain(
config: dict, **kwargs: Any
) -> VectorDBQAWithSourcesChain:
if "vectorstore" in kwargs:
vectorstore = kwargs.pop("vectorstore")
else:
raise ValueError("`vectorstore` must be present.")
if "combine_documents_chain" in config:
combine_documents_chain_config = config.pop("combine_documents_chain")
combine_documents_chain = load_chain_from_config(combine_documents_chain_config)
elif "combine_documents_chain_path" in config:
combine_documents_chain = load_chain(config.pop("combine_documents_chain_path"))
else:
raise ValueError(
"One of `combine_documents_chain` or "
"`combine_documents_chain_path` must be present."
)
return VectorDBQAWithSourcesChain(
combine_documents_chain=combine_documents_chain,
vectorstore=vectorstore,
**config,
)
def _load_vector_db_qa(config: dict, **kwargs: Any) -> VectorDBQA:
if "vectorstore" in kwargs:
vectorstore = kwargs.pop("vectorstore")
else:
raise ValueError("`vectorstore` must be present.")
if "combine_documents_chain" in config:
combine_documents_chain_config = config.pop("combine_documents_chain")
combine_documents_chain = load_chain_from_config(combine_documents_chain_config)
elif "combine_documents_chain_path" in config:
combine_documents_chain = load_chain(config.pop("combine_documents_chain_path"))
else:
raise ValueError(
"One of `combine_documents_chain` or "
"`combine_documents_chain_path` must be present."
)
return VectorDBQA(
combine_documents_chain=combine_documents_chain,
vectorstore=vectorstore,
**config,
)
def _load_api_chain(config: dict, **kwargs: Any) -> APIChain:
if "api_request_chain" in config:
api_request_chain_config = config.pop("api_request_chain")
api_request_chain = load_chain_from_config(api_request_chain_config)
elif "api_request_chain_path" in config:
api_request_chain = load_chain(config.pop("api_request_chain_path"))
else:
raise ValueError(
"One of `api_request_chain` or `api_request_chain_path` must be present."
)
if "api_answer_chain" in config:
api_answer_chain_config = config.pop("api_answer_chain")
api_answer_chain = load_chain_from_config(api_answer_chain_config)
elif "api_answer_chain_path" in config:
api_answer_chain = load_chain(config.pop("api_answer_chain_path"))
else:
raise ValueError(
"One of `api_answer_chain` or `api_answer_chain_path` must be present."
)
if "requests_wrapper" in kwargs:
requests_wrapper = kwargs.pop("requests_wrapper")
else:
raise ValueError("`requests_wrapper` must be present.")
return APIChain(
api_request_chain=api_request_chain,
api_answer_chain=api_answer_chain,
requests_wrapper=requests_wrapper,
**config,
)
def _load_llm_requests_chain(config: dict, **kwargs: Any) -> LLMRequestsChain:
if "llm_chain" in config:
llm_chain_config = config.pop("llm_chain")
llm_chain = load_chain_from_config(llm_chain_config)
elif "llm_chain_path" in config:
llm_chain = load_chain(config.pop("llm_chain_path"))
else:
raise ValueError("One of `llm_chain` or `llm_chain_path` must be present.")
if "requests_wrapper" in kwargs:
requests_wrapper = kwargs.pop("requests_wrapper")
return LLMRequestsChain(
llm_chain=llm_chain, requests_wrapper=requests_wrapper, **config
)
else:
return LLMRequestsChain(llm_chain=llm_chain, **config)
type_to_loader_dict = {
"api_chain": _load_api_chain,
"hyde_chain": _load_hyde_chain,
"llm_chain": _load_llm_chain,
"llm_bash_chain": _load_llm_bash_chain,
"llm_checker_chain": _load_llm_checker_chain,
"llm_math_chain": _load_llm_math_chain,
"llm_requests_chain": _load_llm_requests_chain,
"pal_chain": _load_pal_chain,
"qa_with_sources_chain": _load_qa_with_sources_chain,
"stuff_documents_chain": _load_stuff_documents_chain,
"map_reduce_documents_chain": _load_map_reduce_documents_chain,
"map_rerank_documents_chain": _load_map_rerank_documents_chain,
"refine_documents_chain": _load_refine_documents_chain,
"sql_database_chain": _load_sql_database_chain,
"vector_db_qa_with_sources_chain": _load_vector_db_qa_with_sources_chain,
"vector_db_qa": _load_vector_db_qa,
}
def load_chain_from_config(config: dict) -> Chain: def load_chain_from_config(config: dict, **kwargs: Any) -> Chain:
"""Load chain from Config Dict.""" """Load chain from Config Dict."""
if "_type" not in config: if "_type" not in config:
raise ValueError("Must specify an chain Type in config") raise ValueError("Must specify a chain Type in config")
config_type = config.pop("_type") config_type = config.pop("_type")
if config_type not in type_to_loader_dict: if config_type not in type_to_loader_dict:
raise ValueError(f"Loading {config_type} chain not supported") raise ValueError(f"Loading {config_type} chain not supported")
chain_loader = type_to_loader_dict[config_type] chain_loader = type_to_loader_dict[config_type]
return chain_loader(config) return chain_loader(config, **kwargs)
def load_chain(path: Union[str, Path]) -> Chain: def load_chain(path: Union[str, Path], **kwargs: Any) -> Chain:
"""Unified method for loading a chain from LangChainHub or local fs.""" """Unified method for loading a chain from LangChainHub or local fs."""
if isinstance(path, str) and path.startswith("lc://chains"): if isinstance(path, str) and path.startswith("lc://chains"):
path = os.path.relpath(path, "lc://chains/") path = os.path.relpath(path, "lc://chains/")
return _load_from_hub(path) return _load_from_hub(path, **kwargs)
else: else:
return _load_chain_from_file(path) return _load_chain_from_file(path, **kwargs)
def _load_chain_from_file(file: Union[str, Path]) -> Chain: def _load_chain_from_file(file: Union[str, Path], **kwargs: Any) -> Chain:
"""Load chain from file.""" """Load chain from file."""
# Convert file to Path object. # Convert file to Path object.
if isinstance(file, str): if isinstance(file, str):
@ -79,10 +465,10 @@ def _load_chain_from_file(file: Union[str, Path]) -> Chain:
else: else:
raise ValueError("File type must be json or yaml") raise ValueError("File type must be json or yaml")
# Load the chain from the config now. # Load the chain from the config now.
return load_chain_from_config(config) return load_chain_from_config(config, **kwargs)
def _load_from_hub(path: str) -> Chain: def _load_from_hub(path: str, **kwargs: Any) -> Chain:
"""Load chain from hub.""" """Load chain from hub."""
suffix = path.split(".")[-1] suffix = path.split(".")[-1]
if suffix not in {"json", "yaml"}: if suffix not in {"json", "yaml"}:
@ -95,4 +481,4 @@ def _load_from_hub(path: str) -> Chain:
file = tmpdirname + "/chain." + suffix file = tmpdirname + "/chain." + suffix
with open(file, "wb") as f: with open(file, "wb") as f:
f.write(r.content) f.write(r.content)
return _load_chain_from_file(file) return _load_chain_from_file(file, **kwargs)

@ -94,3 +94,7 @@ class NatBotChain(Chain, BaseModel):
self.input_browser_content_key: browser_content, self.input_browser_content_key: browser_content,
} }
return self(_inputs)[self.output_key] return self(_inputs)[self.output_key]
@property
def _chain_type(self) -> str:
return "nat_bot_chain"

@ -79,3 +79,7 @@ class PALChain(Chain, BaseModel):
get_answer_expr="print(answer)", get_answer_expr="print(answer)",
**kwargs, **kwargs,
) )
@property
def _chain_type(self) -> str:
return "pal_chain"

@ -126,3 +126,7 @@ class QAWithSourcesChain(BaseQAWithSourcesChain, BaseModel):
def _get_docs(self, inputs: Dict[str, Any]) -> List[Document]: def _get_docs(self, inputs: Dict[str, Any]) -> List[Document]:
return inputs.pop(self.input_docs_key) return inputs.pop(self.input_docs_key)
@property
def _chain_type(self) -> str:
return "qa_with_sources_chain"

@ -13,7 +13,7 @@ from langchain.vectorstores.base import VectorStore
class VectorDBQAWithSourcesChain(BaseQAWithSourcesChain, BaseModel): class VectorDBQAWithSourcesChain(BaseQAWithSourcesChain, BaseModel):
"""Question-answering with sources over a vector database.""" """Question-answering with sources over a vector database."""
vectorstore: VectorStore vectorstore: VectorStore = Field(exclude=True)
"""Vector Database to connect to.""" """Vector Database to connect to."""
k: int = 4 k: int = 4
"""Number of results to return from store""" """Number of results to return from store"""
@ -50,3 +50,7 @@ class VectorDBQAWithSourcesChain(BaseQAWithSourcesChain, BaseModel):
question, k=self.k, **self.search_kwargs question, k=self.k, **self.search_kwargs
) )
return self._reduce_tokens_below_limit(docs) return self._reduce_tokens_below_limit(docs)
@property
def _chain_type(self) -> str:
return "vector_db_qa_with_sources_chain"

@ -3,7 +3,7 @@ from __future__ import annotations
from typing import Any, Dict, List from typing import Any, Dict, List
from pydantic import BaseModel, Extra from pydantic import BaseModel, Extra, Field
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
@ -26,7 +26,7 @@ class SQLDatabaseChain(Chain, BaseModel):
llm: BaseLLM llm: BaseLLM
"""LLM wrapper to use.""" """LLM wrapper to use."""
database: SQLDatabase database: SQLDatabase = Field(exclude=True)
"""SQL Database to connect to.""" """SQL Database to connect to."""
prompt: BasePromptTemplate = PROMPT prompt: BasePromptTemplate = PROMPT
"""Prompt to use to translate natural language to SQL.""" """Prompt to use to translate natural language to SQL."""
@ -84,6 +84,10 @@ class SQLDatabaseChain(Chain, BaseModel):
self.callback_manager.on_text(final_result, color="green", verbose=self.verbose) self.callback_manager.on_text(final_result, color="green", verbose=self.verbose)
return {self.output_key: final_result} return {self.output_key: final_result}
@property
def _chain_type(self) -> str:
return "sql_database_chain"
class SQLDatabaseSequentialChain(Chain, BaseModel): class SQLDatabaseSequentialChain(Chain, BaseModel):
"""Chain for querying SQL database that is a sequential chain. """Chain for querying SQL database that is a sequential chain.
@ -153,3 +157,7 @@ class SQLDatabaseSequentialChain(Chain, BaseModel):
"table_names_to_use": table_names_to_use, "table_names_to_use": table_names_to_use,
} }
return self.sql_chain(new_inputs, return_only_outputs=True) return self.sql_chain(new_inputs, return_only_outputs=True)
@property
def _chain_type(self) -> str:
return "sql_database_sequential_chain"

@ -29,7 +29,7 @@ class VectorDBQA(Chain, BaseModel):
""" """
vectorstore: VectorStore vectorstore: VectorStore = Field(exclude=True)
"""Vector Database to connect to.""" """Vector Database to connect to."""
k: int = 4 k: int = 4
"""Number of documents to query for.""" """Number of documents to query for."""
@ -138,3 +138,8 @@ class VectorDBQA(Chain, BaseModel):
return {self.output_key: answer, "source_documents": docs} return {self.output_key: answer, "source_documents": docs}
else: else:
return {self.output_key: answer} return {self.output_key: answer}
@property
def _chain_type(self) -> str:
"""Return the chain type."""
return "vector_db_qa"

Loading…
Cancel
Save