forked from Archives/langchain
Merge branch 'master' into harrison/use_output_parser
This commit is contained in:
commit
cc606180cd
@ -19,6 +19,9 @@ GPT Index is a project consisting of a set of data structures that are created u
|
|||||||
### [Grover's Algorithm](https://github.com/JavaFXpert/llm-grovers-search-party)
|
### [Grover's Algorithm](https://github.com/JavaFXpert/llm-grovers-search-party)
|
||||||
Leveraging Qiskit, OpenAI and LangChain to demonstrate Grover's algorithm
|
Leveraging Qiskit, OpenAI and LangChain to demonstrate Grover's algorithm
|
||||||
|
|
||||||
|
### [ReAct TextWorld](https://colab.research.google.com/drive/19WTIWC3prw5LDMHmRMvqNV2loD9FHls6?usp=sharing)
|
||||||
|
Leveraging the ReActTextWorldAgent to play TextWorld with an LLM!
|
||||||
|
|
||||||
|
|
||||||
## Not Open Source
|
## Not Open Source
|
||||||
|
|
||||||
|
@ -160,7 +160,7 @@
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"agent.run(\"What is the age of Olivia Wilde's boyfriend raised to the 0.23 power?\")"
|
"agent.run(\"How old is Olivia Wilde's boyfriend? What is that number raised to the 0.23 power?\")"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -1 +1 @@
|
|||||||
0.0.26
|
0.0.27
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
"""Chains are easily reusable components which can be linked together."""
|
"""Chains are easily reusable components which can be linked together."""
|
||||||
|
from langchain.chains.api.base import APIChain
|
||||||
from langchain.chains.conversation.base import ConversationChain
|
from langchain.chains.conversation.base import ConversationChain
|
||||||
from langchain.chains.llm import LLMChain
|
from langchain.chains.llm import LLMChain
|
||||||
from langchain.chains.llm_math.base import LLMMathChain
|
from langchain.chains.llm_math.base import LLMMathChain
|
||||||
@ -22,4 +23,5 @@ __all__ = [
|
|||||||
"QAWithSourcesChain",
|
"QAWithSourcesChain",
|
||||||
"VectorDBQAWithSourcesChain",
|
"VectorDBQAWithSourcesChain",
|
||||||
"PALChain",
|
"PALChain",
|
||||||
|
"APIChain",
|
||||||
]
|
]
|
||||||
|
1
langchain/chains/api/__init__.py
Normal file
1
langchain/chains/api/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
"""Chain that makes API calls and summarizes the responses to answer a question."""
|
106
langchain/chains/api/base.py
Normal file
106
langchain/chains/api/base.py
Normal file
@ -0,0 +1,106 @@
|
|||||||
|
"""Chain that makes API calls and summarizes the responses to answer a question."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
import requests
|
||||||
|
from pydantic import BaseModel, root_validator
|
||||||
|
|
||||||
|
from langchain.chains.api.prompt import API_RESPONSE_PROMPT, API_URL_PROMPT
|
||||||
|
from langchain.chains.base import Chain
|
||||||
|
from langchain.chains.llm import LLMChain
|
||||||
|
from langchain.input import print_text
|
||||||
|
from langchain.llms.base import LLM
|
||||||
|
|
||||||
|
|
||||||
|
class RequestsWrapper(BaseModel):
|
||||||
|
"""Lightweight wrapper to partial out everything except the url to hit."""
|
||||||
|
|
||||||
|
headers: Optional[dict] = None
|
||||||
|
|
||||||
|
def run(self, url: str) -> str:
|
||||||
|
"""Hit the URL and return the text."""
|
||||||
|
return requests.get(url, headers=self.headers).text
|
||||||
|
|
||||||
|
|
||||||
|
class APIChain(Chain, BaseModel):
|
||||||
|
"""Chain that makes API calls and summarizes the responses to answer a question."""
|
||||||
|
|
||||||
|
api_request_chain: LLMChain
|
||||||
|
api_answer_chain: LLMChain
|
||||||
|
requests_wrapper: RequestsWrapper
|
||||||
|
api_docs: str
|
||||||
|
question_key: str = "question" #: :meta private:
|
||||||
|
output_key: str = "output" #: :meta private:
|
||||||
|
|
||||||
|
@property
|
||||||
|
def input_keys(self) -> List[str]:
|
||||||
|
"""Expect input key.
|
||||||
|
|
||||||
|
:meta private:
|
||||||
|
"""
|
||||||
|
return [self.question_key]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_keys(self) -> List[str]:
|
||||||
|
"""Expect output key.
|
||||||
|
|
||||||
|
:meta private:
|
||||||
|
"""
|
||||||
|
return [self.output_key]
|
||||||
|
|
||||||
|
@root_validator(pre=True)
|
||||||
|
def validate_api_request_prompt(cls, values: Dict) -> Dict:
|
||||||
|
"""Check that api request prompt expects the right variables."""
|
||||||
|
input_vars = values["api_request_chain"].prompt.input_variables
|
||||||
|
expected_vars = {"question", "api_docs"}
|
||||||
|
if set(input_vars) != expected_vars:
|
||||||
|
raise ValueError(
|
||||||
|
f"Input variables should be {expected_vars}, got {input_vars}"
|
||||||
|
)
|
||||||
|
return values
|
||||||
|
|
||||||
|
@root_validator(pre=True)
|
||||||
|
def validate_api_answer_prompt(cls, values: Dict) -> Dict:
|
||||||
|
"""Check that api answer prompt expects the right variables."""
|
||||||
|
input_vars = values["api_answer_chain"].prompt.input_variables
|
||||||
|
expected_vars = {"question", "api_docs", "api_url", "api_response"}
|
||||||
|
if set(input_vars) != expected_vars:
|
||||||
|
raise ValueError(
|
||||||
|
f"Input variables should be {expected_vars}, got {input_vars}"
|
||||||
|
)
|
||||||
|
return values
|
||||||
|
|
||||||
|
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
||||||
|
question = inputs[self.question_key]
|
||||||
|
api_url = self.api_request_chain.predict(
|
||||||
|
question=question, api_docs=self.api_docs
|
||||||
|
)
|
||||||
|
if self.verbose:
|
||||||
|
print_text(api_url, color="green", end="\n")
|
||||||
|
api_response = self.requests_wrapper.run(api_url)
|
||||||
|
if self.verbose:
|
||||||
|
print_text(api_url, color="yellow", end="\n")
|
||||||
|
answer = self.api_answer_chain.predict(
|
||||||
|
question=question,
|
||||||
|
api_docs=self.api_docs,
|
||||||
|
api_url=api_url,
|
||||||
|
api_response=api_response,
|
||||||
|
)
|
||||||
|
return {self.output_key: answer}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_llm_and_api_docs(
|
||||||
|
cls, llm: LLM, api_docs: str, headers: Optional[dict] = None, **kwargs: Any
|
||||||
|
) -> APIChain:
|
||||||
|
"""Load chain from just an LLM and the api docs."""
|
||||||
|
get_request_chain = LLMChain(llm=llm, prompt=API_URL_PROMPT)
|
||||||
|
requests_wrapper = RequestsWrapper(headers=headers)
|
||||||
|
get_answer_chain = LLMChain(llm=llm, prompt=API_RESPONSE_PROMPT)
|
||||||
|
return cls(
|
||||||
|
api_request_chain=get_request_chain,
|
||||||
|
api_answer_chain=get_answer_chain,
|
||||||
|
requests_wrapper=requests_wrapper,
|
||||||
|
api_docs=api_docs,
|
||||||
|
**kwargs,
|
||||||
|
)
|
35
langchain/chains/api/prompt.py
Normal file
35
langchain/chains/api/prompt.py
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
# flake8: noqa
|
||||||
|
from langchain.prompts.prompt import PromptTemplate
|
||||||
|
|
||||||
|
API_URL_PROMPT_TEMPLATE = """You are given the below API Documentation:
|
||||||
|
|
||||||
|
{api_docs}
|
||||||
|
|
||||||
|
Using this documentation, generate the full API url to call for answering this question: {question}
|
||||||
|
|
||||||
|
API url: """
|
||||||
|
API_URL_PROMPT = PromptTemplate(
|
||||||
|
input_variables=[
|
||||||
|
"api_docs",
|
||||||
|
"question",
|
||||||
|
],
|
||||||
|
template=API_URL_PROMPT_TEMPLATE,
|
||||||
|
)
|
||||||
|
|
||||||
|
API_RESPONSE_PROMPT_TEMPLATE = (
|
||||||
|
API_URL_PROMPT_TEMPLATE
|
||||||
|
+ """ {api_url}
|
||||||
|
|
||||||
|
Here is the response from the API:
|
||||||
|
|
||||||
|
{api_response}
|
||||||
|
|
||||||
|
Summarize this response to answer the original question.
|
||||||
|
|
||||||
|
Summary:"""
|
||||||
|
)
|
||||||
|
|
||||||
|
API_RESPONSE_PROMPT = PromptTemplate(
|
||||||
|
input_variables=["api_docs", "question", "api_url", "api_response"],
|
||||||
|
template=API_RESPONSE_PROMPT_TEMPLATE,
|
||||||
|
)
|
84
tests/unit_tests/chains/test_api.py
Normal file
84
tests/unit_tests/chains/test_api.py
Normal file
@ -0,0 +1,84 @@
|
|||||||
|
"""Test LLM Math functionality."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from langchain import LLMChain
|
||||||
|
from langchain.chains.api.base import APIChain, RequestsWrapper
|
||||||
|
from langchain.chains.api.prompt import API_RESPONSE_PROMPT, API_URL_PROMPT
|
||||||
|
from tests.unit_tests.llms.fake_llm import FakeLLM
|
||||||
|
|
||||||
|
|
||||||
|
class FakeRequestsChain(RequestsWrapper):
|
||||||
|
"""Fake requests chain just for testing purposes."""
|
||||||
|
|
||||||
|
output: str
|
||||||
|
|
||||||
|
def run(self, url: str) -> str:
|
||||||
|
"""Just return the specified output."""
|
||||||
|
return self.output
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def test_api_data() -> dict:
|
||||||
|
"""Fake api data to use for testing."""
|
||||||
|
api_docs = """
|
||||||
|
This API endpoint will search the notes for a user.
|
||||||
|
|
||||||
|
Endpoint: https://thisapidoesntexist.com
|
||||||
|
GET /api/notes
|
||||||
|
|
||||||
|
Query parameters:
|
||||||
|
q | string | The search term for notes
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"api_docs": api_docs,
|
||||||
|
"question": "Search for notes containing langchain",
|
||||||
|
"api_url": "https://thisapidoesntexist.com/api/notes?q=langchain",
|
||||||
|
"api_response": json.dumps(
|
||||||
|
{
|
||||||
|
"success": True,
|
||||||
|
"results": [{"id": 1, "content": "Langchain is awesome!"}],
|
||||||
|
}
|
||||||
|
),
|
||||||
|
"api_summary": "There is 1 note about langchain.",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def fake_llm_api_chain(test_api_data: dict) -> APIChain:
|
||||||
|
"""Fake LLM API chain for testing."""
|
||||||
|
TEST_API_DOCS = test_api_data["api_docs"]
|
||||||
|
TEST_QUESTION = test_api_data["question"]
|
||||||
|
TEST_URL = test_api_data["api_url"]
|
||||||
|
TEST_API_RESPONSE = test_api_data["api_response"]
|
||||||
|
TEST_API_SUMMARY = test_api_data["api_summary"]
|
||||||
|
|
||||||
|
api_url_query_prompt = API_URL_PROMPT.format(
|
||||||
|
api_docs=TEST_API_DOCS, question=TEST_QUESTION
|
||||||
|
)
|
||||||
|
api_response_prompt = API_RESPONSE_PROMPT.format(
|
||||||
|
api_docs=TEST_API_DOCS,
|
||||||
|
question=TEST_QUESTION,
|
||||||
|
api_url=TEST_URL,
|
||||||
|
api_response=TEST_API_RESPONSE,
|
||||||
|
)
|
||||||
|
queries = {api_url_query_prompt: TEST_URL, api_response_prompt: TEST_API_SUMMARY}
|
||||||
|
fake_llm = FakeLLM(queries=queries)
|
||||||
|
api_request_chain = LLMChain(llm=fake_llm, prompt=API_URL_PROMPT)
|
||||||
|
api_answer_chain = LLMChain(llm=fake_llm, prompt=API_RESPONSE_PROMPT)
|
||||||
|
requests_wrapper = FakeRequestsChain(output=TEST_API_RESPONSE)
|
||||||
|
return APIChain(
|
||||||
|
api_request_chain=api_request_chain,
|
||||||
|
api_answer_chain=api_answer_chain,
|
||||||
|
requests_wrapper=requests_wrapper,
|
||||||
|
api_docs=TEST_API_DOCS,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_api_question(fake_llm_api_chain: APIChain, test_api_data: dict) -> None:
|
||||||
|
"""Test simple question that needs API access."""
|
||||||
|
question = test_api_data["question"]
|
||||||
|
output = fake_llm_api_chain.run(question)
|
||||||
|
assert output == test_api_data["api_summary"]
|
Loading…
Reference in New Issue
Block a user