serializing api chains

scad/api-chain
scadEfUr 1 year ago
parent e3df8ab6dc
commit 5f7e8196c6

@ -11,6 +11,7 @@ from langchain.chains.llm import LLMChain
from langchain.llms.base import BaseLLM
from langchain.prompts import BasePromptTemplate
from langchain.requests import RequestsWrapper
from langchain.chains import load_chain
class APIChain(Chain, BaseModel):
@ -102,3 +103,35 @@ class APIChain(Chain, BaseModel):
api_docs=api_docs,
**kwargs,
)
@property
def _chain_type(self) -> str:
return "api_chain"
@classmethod
def from_config(config: Dict) -> APIChain:
try:
api_request_chain_cfg = config.get("api_request_chain")
api_request_chain = load_chain(api_request_chain_cfg)
api_answer_chain_cfg = config.get("api_answer_chain")
api_answer_chain = load_chain(api_answer_chain_cfg)
request_headers = config.get("requests_wrapper").get("headers")
requests_wrapper = RequestsWrapper(headers=request_headers)
api_docs = config.get("api_docs")
question_key = config.get("question_key")
output_key = config.get("output_key")
except:
raise ValueError("Could not load API answer chain.")
return APIChain(
api_request_chain=api_request_chain,
api_answer_chain=api_answer_chain,
requests_wrapper=requests_wrapper,
api_docs=api_docs,
question_key=question_key,
output_key=output_key,
)

@ -226,3 +226,7 @@ class Chain(BaseModel, ABC):
yaml.dump(chain_dict, f, default_flow_style=False)
else:
raise ValueError(f"{save_path} must be json or yaml")
@classmethod
def from_config(config: Dict) -> "Chain":
raise NotImplementedError("Abstract method.")

Loading…
Cancel
Save