Harrison/improve usability of api chain (#247)

improve usability of api chain
harrison/use_output_parser^2
Harrison Chase 2 years ago committed by GitHub
parent c897bd6cbd
commit a9ce04201f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,4 +1,5 @@
"""Chains are easily reusable components which can be linked together."""
from langchain.chains.api.base import APIChain
from langchain.chains.conversation.base import ConversationChain
from langchain.chains.llm import LLMChain
from langchain.chains.llm_math.base import LLMMathChain
@ -22,4 +23,5 @@ __all__ = [
"QAWithSourcesChain",
"VectorDBQAWithSourcesChain",
"PALChain",
"APIChain",
]

@ -6,9 +6,10 @@ from typing import Any, Dict, List, Optional
import requests
from pydantic import BaseModel, root_validator
from langchain import LLMChain
from langchain.chains.api.prompt import API_RESPONSE_PROMPT, API_URL_PROMPT
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.input import print_text
from langchain.llms.base import LLM
@ -27,7 +28,7 @@ class APIChain(Chain, BaseModel):
api_request_chain: LLMChain
api_answer_chain: LLMChain
requests_chain: RequestsWrapper
requests_wrapper: RequestsWrapper
api_docs: str
question_key: str = "question" #: :meta private:
output_key: str = "output" #: :meta private:
@ -75,7 +76,11 @@ class APIChain(Chain, BaseModel):
api_url = self.api_request_chain.predict(
question=question, api_docs=self.api_docs
)
api_response = self.requests_chain.run(api_url)
if self.verbose:
print_text(api_url, color="green", end="\n")
api_response = self.requests_wrapper.run(api_url)
if self.verbose:
print_text(api_url, color="yellow", end="\n")
answer = self.api_answer_chain.predict(
question=question,
api_docs=self.api_docs,
@ -93,8 +98,8 @@ class APIChain(Chain, BaseModel):
requests_wrapper = RequestsWrapper(headers=headers)
get_answer_chain = LLMChain(llm=llm, prompt=API_RESPONSE_PROMPT)
return cls(
request_chain=get_request_chain,
answer_chain=get_answer_chain,
api_request_chain=get_request_chain,
api_answer_chain=get_answer_chain,
requests_wrapper=requests_wrapper,
api_docs=api_docs,
**kwargs,

@ -68,11 +68,11 @@ def fake_llm_api_chain(test_api_data: dict) -> APIChain:
fake_llm = FakeLLM(queries=queries)
api_request_chain = LLMChain(llm=fake_llm, prompt=API_URL_PROMPT)
api_answer_chain = LLMChain(llm=fake_llm, prompt=API_RESPONSE_PROMPT)
requests_chain = FakeRequestsChain(output=TEST_API_RESPONSE)
requests_wrapper = FakeRequestsChain(output=TEST_API_RESPONSE)
return APIChain(
api_request_chain=api_request_chain,
api_answer_chain=api_answer_chain,
requests_chain=requests_chain,
requests_wrapper=requests_wrapper,
api_docs=TEST_API_DOCS,
)

Loading…
Cancel
Save