diff --git a/langchain/chains/api/base.py b/langchain/chains/api/base.py index 3bb5f917..8f3e3de0 100644 --- a/langchain/chains/api/base.py +++ b/langchain/chains/api/base.py @@ -11,6 +11,7 @@ from langchain.chains.llm import LLMChain from langchain.llms.base import BaseLLM from langchain.prompts import BasePromptTemplate from langchain.requests import RequestsWrapper +from langchain.chains import load_chain class APIChain(Chain, BaseModel): @@ -102,3 +103,35 @@ class APIChain(Chain, BaseModel): api_docs=api_docs, **kwargs, ) + + @property + def _chain_type(self) -> str: + return "api_chain" + + @classmethod + def from_config(config: Dict) -> APIChain: + try: + api_request_chain_cfg = config.get("api_request_chain") + api_request_chain = load_chain(api_request_chain_cfg) + + api_answer_chain_cfg = config.get("api_answer_chain") + api_answer_chain = load_chain(api_answer_chain_cfg) + + request_headers = config.get("requests_wrapper").get("headers") + requests_wrapper = RequestsWrapper(headers=request_headers) + + api_docs = config.get("api_docs") + question_key = config.get("question_key") + output_key = config.get("output_key") + + except: + raise ValueError("Could not load API answer chain.") + + return APIChain( + api_request_chain=api_request_chain, + api_answer_chain=api_answer_chain, + requests_wrapper=requests_wrapper, + api_docs=api_docs, + question_key=question_key, + output_key=output_key, + ) diff --git a/langchain/chains/base.py b/langchain/chains/base.py index e30cdb15..72943ff3 100644 --- a/langchain/chains/base.py +++ b/langchain/chains/base.py @@ -226,3 +226,7 @@ class Chain(BaseModel, ABC): yaml.dump(chain_dict, f, default_flow_style=False) else: raise ValueError(f"{save_path} must be json or yaml") + + @classmethod + def from_config(config: Dict) -> "Chain": + raise NotImplementedError("Abstract method.")