langchain/langchain/chains/api/base.py

"""Chain that makes API calls and summarizes the responses to answer a question."""
from __future__ import annotations
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, root_validator
from langchain.chains.api.prompt import API_RESPONSE_PROMPT, API_URL_PROMPT
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.input import print_text
from langchain.llms.base import BaseLLM
from langchain.requests import RequestsWrapper
class APIChain(Chain, BaseModel):
"""Chain that makes API calls and summarizes the responses to answer a question."""
api_request_chain: LLMChain
api_answer_chain: LLMChain
requests_wrapper: RequestsWrapper
api_docs: str
question_key: str = "question" #: :meta private:
output_key: str = "output" #: :meta private:
@property
def input_keys(self) -> List[str]:
"""Expect input key.
:meta private:
"""
return [self.question_key]
@property
def output_keys(self) -> List[str]:
"""Expect output key.
:meta private:
"""
return [self.output_key]
@root_validator(pre=True)
def validate_api_request_prompt(cls, values: Dict) -> Dict:
"""Check that api request prompt expects the right variables."""
input_vars = values["api_request_chain"].prompt.input_variables
expected_vars = {"question", "api_docs"}
if set(input_vars) != expected_vars:
raise ValueError(
f"Input variables should be {expected_vars}, got {input_vars}"
)
return values
@root_validator(pre=True)
def validate_api_answer_prompt(cls, values: Dict) -> Dict:
"""Check that api answer prompt expects the right variables."""
input_vars = values["api_answer_chain"].prompt.input_variables
expected_vars = {"question", "api_docs", "api_url", "api_response"}
if set(input_vars) != expected_vars:
raise ValueError(
f"Input variables should be {expected_vars}, got {input_vars}"
)
return values
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
question = inputs[self.question_key]
api_url = self.api_request_chain.predict(
question=question, api_docs=self.api_docs
)
if self.verbose:
print_text(api_url, color="green", end="\n")
api_response = self.requests_wrapper.run(api_url)
if self.verbose:
print_text(api_response, color="yellow", end="\n")
answer = self.api_answer_chain.predict(
question=question,
api_docs=self.api_docs,
api_url=api_url,
api_response=api_response,
)
return {self.output_key: answer}
@classmethod
def from_llm_and_api_docs(
cls, llm: BaseLLM, api_docs: str, headers: Optional[dict] = None, **kwargs: Any
) -> APIChain:
"""Load chain from just an LLM and the api docs."""
get_request_chain = LLMChain(llm=llm, prompt=API_URL_PROMPT)
requests_wrapper = RequestsWrapper(headers=headers)
get_answer_chain = LLMChain(llm=llm, prompt=API_RESPONSE_PROMPT)
return cls(
api_request_chain=get_request_chain,
api_answer_chain=get_answer_chain,
requests_wrapper=requests_wrapper,
api_docs=api_docs,
**kwargs,
)
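
For reference, a minimal usage sketch of the module above. It assumes the OpenAI wrapper exported from langchain.llms is available in this version of the library and that an OpenAI API key is configured in the environment; the api.example.com endpoint and the api_docs string are made up for illustration and are not part of this module.

# Usage sketch (assumptions: langchain.llms.OpenAI is importable and
# OPENAI_API_KEY is set; the API description below is hypothetical).
from langchain.chains.api.base import APIChain
from langchain.llms import OpenAI

api_docs = """
Base URL: https://api.example.com

GET /weather?city={city}
Returns a JSON object with the current temperature in Celsius for the given city.
"""

llm = OpenAI(temperature=0)
chain = APIChain.from_llm_and_api_docs(llm, api_docs, verbose=True)

# The request chain writes the URL, RequestsWrapper fetches it, and the
# answer chain turns the raw response into the final string under "output".
print(chain.run("What is the current temperature in Berlin?"))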