Merge branch 'master' into harrison/use_output_parser

This commit is contained in:
Harrison Chase 2022-12-03 13:13:34 -08:00
commit cc606180cd
8 changed files with 233 additions and 2 deletions

View File

@ -19,6 +19,9 @@ GPT Index is a project consisting of a set of data structures that are created u
### [Grover's Algorithm](https://github.com/JavaFXpert/llm-grovers-search-party)
Leveraging Qiskit, OpenAI and LangChain to demonstrate Grover's algorithm
### [ReAct TextWorld](https://colab.research.google.com/drive/19WTIWC3prw5LDMHmRMvqNV2loD9FHls6?usp=sharing)
Leveraging the ReActTextWorldAgent to play TextWorld with an LLM!
## Not Open Source

View File

@ -160,7 +160,7 @@
}
],
"source": [
"agent.run(\"What is the age of Olivia Wilde's boyfriend raised to the 0.23 power?\")"
"agent.run(\"How old is Olivia Wilde's boyfriend? What is that number raised to the 0.23 power?\")"
]
},
{

View File

@ -1 +1 @@
0.0.26
0.0.27

View File

@ -1,4 +1,5 @@
"""Chains are easily reusable components which can be linked together."""
from langchain.chains.api.base import APIChain
from langchain.chains.conversation.base import ConversationChain
from langchain.chains.llm import LLMChain
from langchain.chains.llm_math.base import LLMMathChain
@ -22,4 +23,5 @@ __all__ = [
"QAWithSourcesChain",
"VectorDBQAWithSourcesChain",
"PALChain",
"APIChain",
]

View File

@ -0,0 +1 @@
"""Chain that makes API calls and summarizes the responses to answer a question."""

View File

@ -0,0 +1,106 @@
"""Chain that makes API calls and summarizes the responses to answer a question."""
from __future__ import annotations
from typing import Any, Dict, List, Optional
import requests
from pydantic import BaseModel, root_validator
from langchain.chains.api.prompt import API_RESPONSE_PROMPT, API_URL_PROMPT
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.input import print_text
from langchain.llms.base import LLM
class RequestsWrapper(BaseModel):
"""Lightweight wrapper to partial out everything except the url to hit."""
headers: Optional[dict] = None
def run(self, url: str) -> str:
"""Hit the URL and return the text."""
return requests.get(url, headers=self.headers).text
class APIChain(Chain, BaseModel):
"""Chain that makes API calls and summarizes the responses to answer a question."""
api_request_chain: LLMChain
api_answer_chain: LLMChain
requests_wrapper: RequestsWrapper
api_docs: str
question_key: str = "question" #: :meta private:
output_key: str = "output" #: :meta private:
@property
def input_keys(self) -> List[str]:
"""Expect input key.
:meta private:
"""
return [self.question_key]
@property
def output_keys(self) -> List[str]:
"""Expect output key.
:meta private:
"""
return [self.output_key]
@root_validator(pre=True)
def validate_api_request_prompt(cls, values: Dict) -> Dict:
"""Check that api request prompt expects the right variables."""
input_vars = values["api_request_chain"].prompt.input_variables
expected_vars = {"question", "api_docs"}
if set(input_vars) != expected_vars:
raise ValueError(
f"Input variables should be {expected_vars}, got {input_vars}"
)
return values
@root_validator(pre=True)
def validate_api_answer_prompt(cls, values: Dict) -> Dict:
"""Check that api answer prompt expects the right variables."""
input_vars = values["api_answer_chain"].prompt.input_variables
expected_vars = {"question", "api_docs", "api_url", "api_response"}
if set(input_vars) != expected_vars:
raise ValueError(
f"Input variables should be {expected_vars}, got {input_vars}"
)
return values
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
question = inputs[self.question_key]
api_url = self.api_request_chain.predict(
question=question, api_docs=self.api_docs
)
if self.verbose:
print_text(api_url, color="green", end="\n")
api_response = self.requests_wrapper.run(api_url)
if self.verbose:
print_text(api_url, color="yellow", end="\n")
answer = self.api_answer_chain.predict(
question=question,
api_docs=self.api_docs,
api_url=api_url,
api_response=api_response,
)
return {self.output_key: answer}
@classmethod
def from_llm_and_api_docs(
cls, llm: LLM, api_docs: str, headers: Optional[dict] = None, **kwargs: Any
) -> APIChain:
"""Load chain from just an LLM and the api docs."""
get_request_chain = LLMChain(llm=llm, prompt=API_URL_PROMPT)
requests_wrapper = RequestsWrapper(headers=headers)
get_answer_chain = LLMChain(llm=llm, prompt=API_RESPONSE_PROMPT)
return cls(
api_request_chain=get_request_chain,
api_answer_chain=get_answer_chain,
requests_wrapper=requests_wrapper,
api_docs=api_docs,
**kwargs,
)

View File

@ -0,0 +1,35 @@
# flake8: noqa
from langchain.prompts.prompt import PromptTemplate
API_URL_PROMPT_TEMPLATE = """You are given the below API Documentation:
{api_docs}
Using this documentation, generate the full API url to call for answering this question: {question}
API url: """
API_URL_PROMPT = PromptTemplate(
input_variables=[
"api_docs",
"question",
],
template=API_URL_PROMPT_TEMPLATE,
)
API_RESPONSE_PROMPT_TEMPLATE = (
API_URL_PROMPT_TEMPLATE
+ """ {api_url}
Here is the response from the API:
{api_response}
Summarize this response to answer the original question.
Summary:"""
)
API_RESPONSE_PROMPT = PromptTemplate(
input_variables=["api_docs", "question", "api_url", "api_response"],
template=API_RESPONSE_PROMPT_TEMPLATE,
)

View File

@ -0,0 +1,84 @@
"""Test LLM Math functionality."""
import json
import pytest
from langchain import LLMChain
from langchain.chains.api.base import APIChain, RequestsWrapper
from langchain.chains.api.prompt import API_RESPONSE_PROMPT, API_URL_PROMPT
from tests.unit_tests.llms.fake_llm import FakeLLM
class FakeRequestsChain(RequestsWrapper):
"""Fake requests chain just for testing purposes."""
output: str
def run(self, url: str) -> str:
"""Just return the specified output."""
return self.output
@pytest.fixture
def test_api_data() -> dict:
"""Fake api data to use for testing."""
api_docs = """
This API endpoint will search the notes for a user.
Endpoint: https://thisapidoesntexist.com
GET /api/notes
Query parameters:
q | string | The search term for notes
"""
return {
"api_docs": api_docs,
"question": "Search for notes containing langchain",
"api_url": "https://thisapidoesntexist.com/api/notes?q=langchain",
"api_response": json.dumps(
{
"success": True,
"results": [{"id": 1, "content": "Langchain is awesome!"}],
}
),
"api_summary": "There is 1 note about langchain.",
}
@pytest.fixture
def fake_llm_api_chain(test_api_data: dict) -> APIChain:
"""Fake LLM API chain for testing."""
TEST_API_DOCS = test_api_data["api_docs"]
TEST_QUESTION = test_api_data["question"]
TEST_URL = test_api_data["api_url"]
TEST_API_RESPONSE = test_api_data["api_response"]
TEST_API_SUMMARY = test_api_data["api_summary"]
api_url_query_prompt = API_URL_PROMPT.format(
api_docs=TEST_API_DOCS, question=TEST_QUESTION
)
api_response_prompt = API_RESPONSE_PROMPT.format(
api_docs=TEST_API_DOCS,
question=TEST_QUESTION,
api_url=TEST_URL,
api_response=TEST_API_RESPONSE,
)
queries = {api_url_query_prompt: TEST_URL, api_response_prompt: TEST_API_SUMMARY}
fake_llm = FakeLLM(queries=queries)
api_request_chain = LLMChain(llm=fake_llm, prompt=API_URL_PROMPT)
api_answer_chain = LLMChain(llm=fake_llm, prompt=API_RESPONSE_PROMPT)
requests_wrapper = FakeRequestsChain(output=TEST_API_RESPONSE)
return APIChain(
api_request_chain=api_request_chain,
api_answer_chain=api_answer_chain,
requests_wrapper=requests_wrapper,
api_docs=TEST_API_DOCS,
)
def test_api_question(fake_llm_api_chain: APIChain, test_api_data: dict) -> None:
"""Test simple question that needs API access."""
question = test_api_data["question"]
output = fake_llm_api_chain.run(question)
assert output == test_api_data["api_summary"]