|
|
|
@ -6,9 +6,10 @@ from typing import Any, Dict, List, Optional
|
|
|
|
|
import requests
|
|
|
|
|
from pydantic import BaseModel, root_validator
|
|
|
|
|
|
|
|
|
|
from langchain import LLMChain
|
|
|
|
|
from langchain.chains.api.prompt import API_RESPONSE_PROMPT, API_URL_PROMPT
|
|
|
|
|
from langchain.chains.base import Chain
|
|
|
|
|
from langchain.chains.llm import LLMChain
|
|
|
|
|
from langchain.input import print_text
|
|
|
|
|
from langchain.llms.base import LLM
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -27,7 +28,7 @@ class APIChain(Chain, BaseModel):
|
|
|
|
|
|
|
|
|
|
api_request_chain: LLMChain
|
|
|
|
|
api_answer_chain: LLMChain
|
|
|
|
|
requests_chain: RequestsWrapper
|
|
|
|
|
requests_wrapper: RequestsWrapper
|
|
|
|
|
api_docs: str
|
|
|
|
|
question_key: str = "question" #: :meta private:
|
|
|
|
|
output_key: str = "output" #: :meta private:
|
|
|
|
@ -75,7 +76,11 @@ class APIChain(Chain, BaseModel):
|
|
|
|
|
api_url = self.api_request_chain.predict(
|
|
|
|
|
question=question, api_docs=self.api_docs
|
|
|
|
|
)
|
|
|
|
|
api_response = self.requests_chain.run(api_url)
|
|
|
|
|
if self.verbose:
|
|
|
|
|
print_text(api_url, color="green", end="\n")
|
|
|
|
|
api_response = self.requests_wrapper.run(api_url)
|
|
|
|
|
if self.verbose:
|
|
|
|
|
print_text(api_url, color="yellow", end="\n")
|
|
|
|
|
answer = self.api_answer_chain.predict(
|
|
|
|
|
question=question,
|
|
|
|
|
api_docs=self.api_docs,
|
|
|
|
@ -93,8 +98,8 @@ class APIChain(Chain, BaseModel):
|
|
|
|
|
requests_wrapper = RequestsWrapper(headers=headers)
|
|
|
|
|
get_answer_chain = LLMChain(llm=llm, prompt=API_RESPONSE_PROMPT)
|
|
|
|
|
return cls(
|
|
|
|
|
request_chain=get_request_chain,
|
|
|
|
|
answer_chain=get_answer_chain,
|
|
|
|
|
api_request_chain=get_request_chain,
|
|
|
|
|
api_answer_chain=get_answer_chain,
|
|
|
|
|
requests_wrapper=requests_wrapper,
|
|
|
|
|
api_docs=api_docs,
|
|
|
|
|
**kwargs,
|
|
|
|
|