You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
langchain/langchain/chains/api/base.py

138 lines
4.7 KiB
Python

"""Chain that makes API calls and summarizes the responses to answer a question."""
from __future__ import annotations
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, root_validator
from langchain.chains.api.prompt import API_RESPONSE_PROMPT, API_URL_PROMPT
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.llms.base import BaseLLM
from langchain.prompts import BasePromptTemplate
from langchain.requests import RequestsWrapper
from langchain.chains import load_chain
class APIChain(Chain, BaseModel):
    """Chain that makes API calls and summarizes the responses to answer a question.

    One run of the chain performs three steps:

    1. ``api_request_chain`` is prompted with the question and ``api_docs``;
       its text output is used verbatim as the URL to request.
    2. ``requests_wrapper.run`` fetches that URL.
    3. ``api_answer_chain`` is prompted with the question, the docs, the URL
       and the raw response; its output is returned under ``output_key``.
    """

    api_request_chain: LLMChain
    api_answer_chain: LLMChain
    requests_wrapper: RequestsWrapper
    api_docs: str
    question_key: str = "question"  #: :meta private:
    output_key: str = "output"  #: :meta private:

    @property
    def input_keys(self) -> List[str]:
        """Expect input key.

        :meta private:
        """
        return [self.question_key]

    @property
    def output_keys(self) -> List[str]:
        """Expect output key.

        :meta private:
        """
        return [self.output_key]

    @root_validator(pre=True)
    def validate_api_request_prompt(cls, values: Dict) -> Dict:
        """Check that api request prompt expects the right variables.

        The prompt of ``api_request_chain`` must take exactly ``question`` and
        ``api_docs``, since those are the only arguments ``_call`` passes it.

        Raises:
            ValueError: If the prompt's input variables differ.
        """
        chain = values.get("api_request_chain")
        if chain is None:
            # Field is required; let pydantic's own required-field check report
            # it rather than leaking a raw KeyError out of this pre-validator.
            return values
        input_vars = chain.prompt.input_variables
        expected_vars = {"question", "api_docs"}
        if set(input_vars) != expected_vars:
            raise ValueError(
                f"Input variables should be {expected_vars}, got {input_vars}"
            )
        return values

    @root_validator(pre=True)
    def validate_api_answer_prompt(cls, values: Dict) -> Dict:
        """Check that api answer prompt expects the right variables.

        The prompt of ``api_answer_chain`` must take exactly ``question``,
        ``api_docs``, ``api_url`` and ``api_response``, matching ``_call``.

        Raises:
            ValueError: If the prompt's input variables differ.
        """
        chain = values.get("api_answer_chain")
        if chain is None:
            # Same reasoning as validate_api_request_prompt: defer to pydantic.
            return values
        input_vars = chain.prompt.input_variables
        expected_vars = {"question", "api_docs", "api_url", "api_response"}
        if set(input_vars) != expected_vars:
            raise ValueError(
                f"Input variables should be {expected_vars}, got {input_vars}"
            )
        return values

    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
        """Run the request/fetch/answer pipeline for one question.

        Args:
            inputs: Mapping that contains the question under ``question_key``.

        Returns:
            ``{output_key: answer}`` where ``answer`` is the answer chain output.
        """
        question = inputs[self.question_key]
        # The LLM output is used directly as the URL to fetch.
        api_url = self.api_request_chain.predict(
            question=question, api_docs=self.api_docs
        )
        self.callback_manager.on_text(
            api_url, color="green", end="\n", verbose=self.verbose
        )
        api_response = self.requests_wrapper.run(api_url)
        self.callback_manager.on_text(
            api_response, color="yellow", end="\n", verbose=self.verbose
        )
        answer = self.api_answer_chain.predict(
            question=question,
            api_docs=self.api_docs,
            api_url=api_url,
            api_response=api_response,
        )
        return {self.output_key: answer}

    @classmethod
    def from_llm_and_api_docs(
        cls,
        llm: BaseLLM,
        api_docs: str,
        headers: Optional[dict] = None,
        api_url_prompt: BasePromptTemplate = API_URL_PROMPT,
        api_response_prompt: BasePromptTemplate = API_RESPONSE_PROMPT,
        **kwargs: Any,
    ) -> APIChain:
        """Load chain from just an LLM and the api docs.

        Args:
            llm: LLM used for both the URL-building and the answering step.
            api_docs: API documentation text inserted into both prompts.
            headers: Optional HTTP headers handed to ``RequestsWrapper``.
            api_url_prompt: Prompt for the URL-building chain.
            api_response_prompt: Prompt for the answering chain.
            **kwargs: Extra fields forwarded to the constructor.
        """
        get_request_chain = LLMChain(llm=llm, prompt=api_url_prompt)
        requests_wrapper = RequestsWrapper(headers=headers)
        get_answer_chain = LLMChain(llm=llm, prompt=api_response_prompt)
        return cls(
            api_request_chain=get_request_chain,
            api_answer_chain=get_answer_chain,
            requests_wrapper=requests_wrapper,
            api_docs=api_docs,
            **kwargs,
        )

    @property
    def _chain_type(self) -> str:
        return "api_chain"

    @classmethod
    def from_config(cls, config: Dict) -> APIChain:
        """Build a chain from a serialized config dict.

        Args:
            config: Mapping with required keys ``api_request_chain`` and
                ``api_answer_chain`` (each passed to ``load_chain``) and
                ``api_docs``; optional ``requests_wrapper`` (a mapping whose
                ``headers`` entry is used), ``question_key`` and ``output_key``
                (defaulting to ``"question"`` / ``"output"``).

        Returns:
            A new instance of ``cls``.

        Raises:
            ValueError: If a required key is missing or a sub-chain cannot be
                loaded.
        """
        # NOTE(review): confirm that ``load_chain`` accepts a config dict here
        # rather than a file path, and that the module-level
        # ``from langchain.chains import load_chain`` does not create an import
        # cycle with ``langchain.chains`` — verify.
        required = ("api_request_chain", "api_answer_chain", "api_docs")
        missing = [key for key in required if key not in config]
        if missing:
            raise ValueError(f"APIChain config is missing required keys: {missing}")
        try:
            api_request_chain = load_chain(config["api_request_chain"])
            api_answer_chain = load_chain(config["api_answer_chain"])
        except Exception as err:
            # Boundary: wrap any loader failure, but keep the original cause.
            raise ValueError("Could not load APIChain sub-chains from config.") from err
        # Headers are optional everywhere else in this class, so a missing
        # ``requests_wrapper`` section means "no headers", not an error.
        wrapper_cfg = config.get("requests_wrapper") or {}
        requests_wrapper = RequestsWrapper(headers=wrapper_cfg.get("headers"))
        return cls(
            api_request_chain=api_request_chain,
            api_answer_chain=api_answer_chain,
            requests_wrapper=requests_wrapper,
            api_docs=config["api_docs"],
            question_key=config.get("question_key") or "question",
            output_key=config.get("output_key") or "output",
        )