serializing api chains

scad/api-chain
scadEfUr 1 year ago
parent e3df8ab6dc
commit 5f7e8196c6

@ -11,6 +11,7 @@ from langchain.chains.llm import LLMChain
from langchain.llms.base import BaseLLM from langchain.llms.base import BaseLLM
from langchain.prompts import BasePromptTemplate from langchain.prompts import BasePromptTemplate
from langchain.requests import RequestsWrapper from langchain.requests import RequestsWrapper
from langchain.chains import load_chain
class APIChain(Chain, BaseModel): class APIChain(Chain, BaseModel):
@ -102,3 +103,35 @@ class APIChain(Chain, BaseModel):
api_docs=api_docs, api_docs=api_docs,
**kwargs, **kwargs,
) )
@property
def _chain_type(self) -> str:
return "api_chain"
@classmethod
def from_config(config: Dict) -> APIChain:
try:
api_request_chain_cfg = config.get("api_request_chain")
api_request_chain = load_chain(api_request_chain_cfg)
api_answer_chain_cfg = config.get("api_answer_chain")
api_answer_chain = load_chain(api_answer_chain_cfg)
request_headers = config.get("requests_wrapper").get("headers")
requests_wrapper = RequestsWrapper(headers=request_headers)
api_docs = config.get("api_docs")
question_key = config.get("question_key")
output_key = config.get("output_key")
except:
raise ValueError("Could not load API answer chain.")
return APIChain(
api_request_chain=api_request_chain,
api_answer_chain=api_answer_chain,
requests_wrapper=requests_wrapper,
api_docs=api_docs,
question_key=question_key,
output_key=output_key,
)

@ -226,3 +226,7 @@ class Chain(BaseModel, ABC):
yaml.dump(chain_dict, f, default_flow_style=False) yaml.dump(chain_dict, f, default_flow_style=False)
else: else:
raise ValueError(f"{save_path} must be json or yaml") raise ValueError(f"{save_path} must be json or yaml")
@classmethod
def from_config(config: Dict) -> "Chain":
raise NotImplementedError("Abstract method.")

Loading…
Cancel
Save