Merge branch 'master' into harrison/use_output_parser

Harrison Chase 2022-12-03 13:13:34 -08:00
commit cc606180cd
8 changed files with 233 additions and 2 deletions

View File

@ -19,6 +19,9 @@ GPT Index is a project consisting of a set of data structures that are created u
### [Grover's Algorithm](https://github.com/JavaFXpert/llm-grovers-search-party)
Leveraging Qiskit, OpenAI and LangChain to demonstrate Grover's algorithm
### [ReAct TextWorld](https://colab.research.google.com/drive/19WTIWC3prw5LDMHmRMvqNV2loD9FHls6?usp=sharing)
Leveraging the ReActTextWorldAgent to play TextWorld with an LLM!
## Not Open Source

View File

@ -160,7 +160,7 @@
}
],
"source": [
- "agent.run(\"What is the age of Olivia Wilde's boyfriend raised to the 0.23 power?\")"
+ "agent.run(\"How old is Olivia Wilde's boyfriend? What is that number raised to the 0.23 power?\")"
]
},
{

View File

@ -1 +1 @@
- 0.0.26
+ 0.0.27

View File

@ -1,4 +1,5 @@
"""Chains are easily reusable components which can be linked together.""" """Chains are easily reusable components which can be linked together."""
from langchain.chains.api.base import APIChain
from langchain.chains.conversation.base import ConversationChain from langchain.chains.conversation.base import ConversationChain
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.chains.llm_math.base import LLMMathChain from langchain.chains.llm_math.base import LLMMathChain
@ -22,4 +23,5 @@ __all__ = [
"QAWithSourcesChain", "QAWithSourcesChain",
"VectorDBQAWithSourcesChain", "VectorDBQAWithSourcesChain",
"PALChain", "PALChain",
"APIChain",
] ]
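With the export above in place, the new chain can be imported straight from `langchain.chains`; a quick sanity check, assuming the 0.0.27 package bumped in this commit:

```python
# Confirms the new export; the class itself lives in langchain.chains.api.base.
from langchain.chains import APIChain

print(APIChain.__doc__)
```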

View File

@ -0,0 +1 @@
"""Chain that makes API calls and summarizes the responses to answer a question."""

View File

@ -0,0 +1,106 @@
"""Chain that makes API calls and summarizes the responses to answer a question."""
from __future__ import annotations
from typing import Any, Dict, List, Optional
import requests
from pydantic import BaseModel, root_validator
from langchain.chains.api.prompt import API_RESPONSE_PROMPT, API_URL_PROMPT
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.input import print_text
from langchain.llms.base import LLM
class RequestsWrapper(BaseModel):
"""Lightweight wrapper to partial out everything except the url to hit."""
headers: Optional[dict] = None
def run(self, url: str) -> str:
"""Hit the URL and return the text."""
return requests.get(url, headers=self.headers).text
class APIChain(Chain, BaseModel):
"""Chain that makes API calls and summarizes the responses to answer a question."""
api_request_chain: LLMChain
api_answer_chain: LLMChain
requests_wrapper: RequestsWrapper
api_docs: str
question_key: str = "question" #: :meta private:
output_key: str = "output" #: :meta private:
@property
def input_keys(self) -> List[str]:
"""Expect input key.
:meta private:
"""
return [self.question_key]
@property
def output_keys(self) -> List[str]:
"""Expect output key.
:meta private:
"""
return [self.output_key]
@root_validator(pre=True)
def validate_api_request_prompt(cls, values: Dict) -> Dict:
"""Check that api request prompt expects the right variables."""
input_vars = values["api_request_chain"].prompt.input_variables
expected_vars = {"question", "api_docs"}
if set(input_vars) != expected_vars:
raise ValueError(
f"Input variables should be {expected_vars}, got {input_vars}"
)
return values
@root_validator(pre=True)
def validate_api_answer_prompt(cls, values: Dict) -> Dict:
"""Check that api answer prompt expects the right variables."""
input_vars = values["api_answer_chain"].prompt.input_variables
expected_vars = {"question", "api_docs", "api_url", "api_response"}
if set(input_vars) != expected_vars:
raise ValueError(
f"Input variables should be {expected_vars}, got {input_vars}"
)
return values
    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
        question = inputs[self.question_key]
        api_url = self.api_request_chain.predict(
            question=question, api_docs=self.api_docs
        )
        if self.verbose:
            print_text(api_url, color="green", end="\n")
        api_response = self.requests_wrapper.run(api_url)
        if self.verbose:
            # Print the API response here; the URL was already printed in green above.
            print_text(api_response, color="yellow", end="\n")
        answer = self.api_answer_chain.predict(
            question=question,
            api_docs=self.api_docs,
            api_url=api_url,
            api_response=api_response,
        )
        return {self.output_key: answer}
    @classmethod
    def from_llm_and_api_docs(
        cls, llm: LLM, api_docs: str, headers: Optional[dict] = None, **kwargs: Any
    ) -> APIChain:
        """Load chain from just an LLM and the api docs."""
        get_request_chain = LLMChain(llm=llm, prompt=API_URL_PROMPT)
        requests_wrapper = RequestsWrapper(headers=headers)
        get_answer_chain = LLMChain(llm=llm, prompt=API_RESPONSE_PROMPT)
        return cls(
            api_request_chain=get_request_chain,
            api_answer_chain=get_answer_chain,
            requests_wrapper=requests_wrapper,
            api_docs=api_docs,
            **kwargs,
        )
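For orientation, here is a sketch of how this chain might be wired up end to end. The documentation string and endpoint below are invented for illustration, and the OpenAI wrapper (plus an API key in the environment) is an assumption; the unit tests in this commit exercise the same flow with a FakeLLM instead.

```python
from langchain.chains.api.base import APIChain
from langchain.llms import OpenAI

# Hypothetical API docs; anything descriptive enough for the LLM to build a URL from.
WEATHER_API_DOCS = """
Endpoint: https://example-weather.invalid
GET /v1/current
Query parameters:
city | string | Name of the city to look up
"""

llm = OpenAI(temperature=0)  # assumes OPENAI_API_KEY is set in the environment
chain = APIChain.from_llm_and_api_docs(llm, WEATHER_API_DOCS, verbose=True)

# The chain predicts a URL from the docs, GETs it through RequestsWrapper,
# then asks the LLM to summarize the response into an answer.
# chain.run("What is the current weather in Berlin?")
```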

View File

@ -0,0 +1,35 @@
# flake8: noqa
from langchain.prompts.prompt import PromptTemplate

API_URL_PROMPT_TEMPLATE = """You are given the below API Documentation:
{api_docs}
Using this documentation, generate the full API url to call for answering this question: {question}
API url: """

API_URL_PROMPT = PromptTemplate(
    input_variables=[
        "api_docs",
        "question",
    ],
    template=API_URL_PROMPT_TEMPLATE,
)

API_RESPONSE_PROMPT_TEMPLATE = (
    API_URL_PROMPT_TEMPLATE
    + """ {api_url}
Here is the response from the API:
{api_response}
Summarize this response to answer the original question.
Summary:"""
)

API_RESPONSE_PROMPT = PromptTemplate(
    input_variables=["api_docs", "question", "api_url", "api_response"],
    template=API_RESPONSE_PROMPT_TEMPLATE,
)
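To see exactly what the request prompt sends to the model, it can be formatted by hand, the same way the unit tests below do; the docs and question here are illustrative only:

```python
from langchain.chains.api.prompt import API_URL_PROMPT

prompt_text = API_URL_PROMPT.format(
    api_docs="GET https://thisapidoesntexist.com/api/notes?q=<term> | search a user's notes",
    question="Search for notes containing langchain",
)
# The formatted prompt ends with "API url: ", which the LLM is expected to complete.
print(prompt_text)
```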

View File

@ -0,0 +1,84 @@
"""Test LLM Math functionality."""
import json
import pytest
from langchain import LLMChain
from langchain.chains.api.base import APIChain, RequestsWrapper
from langchain.chains.api.prompt import API_RESPONSE_PROMPT, API_URL_PROMPT
from tests.unit_tests.llms.fake_llm import FakeLLM
class FakeRequestsChain(RequestsWrapper):
    """Fake requests chain just for testing purposes."""

    output: str

    def run(self, url: str) -> str:
        """Just return the specified output."""
        return self.output


@pytest.fixture
def test_api_data() -> dict:
    """Fake api data to use for testing."""
    api_docs = """
This API endpoint will search the notes for a user.
Endpoint: https://thisapidoesntexist.com
GET /api/notes
Query parameters:
q | string | The search term for notes
"""
    return {
        "api_docs": api_docs,
        "question": "Search for notes containing langchain",
        "api_url": "https://thisapidoesntexist.com/api/notes?q=langchain",
        "api_response": json.dumps(
            {
                "success": True,
                "results": [{"id": 1, "content": "Langchain is awesome!"}],
            }
        ),
        "api_summary": "There is 1 note about langchain.",
    }
@pytest.fixture
def fake_llm_api_chain(test_api_data: dict) -> APIChain:
    """Fake LLM API chain for testing."""
    TEST_API_DOCS = test_api_data["api_docs"]
    TEST_QUESTION = test_api_data["question"]
    TEST_URL = test_api_data["api_url"]
    TEST_API_RESPONSE = test_api_data["api_response"]
    TEST_API_SUMMARY = test_api_data["api_summary"]
    api_url_query_prompt = API_URL_PROMPT.format(
        api_docs=TEST_API_DOCS, question=TEST_QUESTION
    )
    api_response_prompt = API_RESPONSE_PROMPT.format(
        api_docs=TEST_API_DOCS,
        question=TEST_QUESTION,
        api_url=TEST_URL,
        api_response=TEST_API_RESPONSE,
    )
    queries = {api_url_query_prompt: TEST_URL, api_response_prompt: TEST_API_SUMMARY}
    fake_llm = FakeLLM(queries=queries)
    api_request_chain = LLMChain(llm=fake_llm, prompt=API_URL_PROMPT)
    api_answer_chain = LLMChain(llm=fake_llm, prompt=API_RESPONSE_PROMPT)
    requests_wrapper = FakeRequestsChain(output=TEST_API_RESPONSE)
    return APIChain(
        api_request_chain=api_request_chain,
        api_answer_chain=api_answer_chain,
        requests_wrapper=requests_wrapper,
        api_docs=TEST_API_DOCS,
    )


def test_api_question(fake_llm_api_chain: APIChain, test_api_data: dict) -> None:
    """Test simple question that needs API access."""
    question = test_api_data["question"]
    output = fake_llm_api_chain.run(question)
    assert output == test_api_data["api_summary"]
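One note on the fixture above: both prompts are pre-formatted because FakeLLM (from tests.unit_tests.llms.fake_llm) appears to look up its canned completion by the exact prompt string it receives. A minimal sketch of that assumed lookup, not the real helper:

```python
# Hypothetical stand-in for FakeLLM's behavior; the real test helper may differ in detail.
queries = {
    "<formatted API_URL_PROMPT>": "https://thisapidoesntexist.com/api/notes?q=langchain",
    "<formatted API_RESPONSE_PROMPT>": "There is 1 note about langchain.",
}


def fake_completion(prompt: str) -> str:
    """Return the canned output registered for this exact prompt."""
    return queries[prompt]
```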