LLMRequestsChain (#267)

harrison/promot-mrkl
Harrison Chase 1 year ago committed by GitHub
parent 68666d6a22
commit 28be37f470
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,123 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "dd7ec7af",
"metadata": {},
"source": [
"# LLMRequestsChain\n",
"\n",
"Using the request library to get HTML results from a URL and then an LLM to parse results"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "dd8eae75",
"metadata": {},
"outputs": [],
"source": [
"from langchain.llms import OpenAI\n",
"from langchain.chains import LLMRequestsChain, LLMChain"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "65bf324e",
"metadata": {},
"outputs": [],
"source": [
"from langchain.prompts import PromptTemplate\n",
"\n",
"template = \"\"\"Between >>> and <<< are the raw search result text from google.\n",
"Extract the answer to the question '{query}' or say \"not found\" if the information is not contained.\n",
"Use the format\n",
"Extracted:<answer or \"not found\">\n",
">>> {requests_result} <<<\n",
"Extracted:\"\"\"\n",
"\n",
"PROMPT = PromptTemplate(\n",
" input_variables=[\"query\", \"requests_result\"],\n",
" template=template,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "f36ae0d8",
"metadata": {},
"outputs": [],
"source": [
"chain = LLMRequestsChain(llm_chain = LLMChain(llm=OpenAI(temperature=0), prompt=PROMPT))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "b5d22d9d",
"metadata": {},
"outputs": [],
"source": [
"question = \"What are the Three (3) biggest countries, and their respective sizes?\"\n",
"inputs = {\n",
" \"query\": question,\n",
" \"url\": \"https://www.google.com/search?q=\" + question.replace(\" \", \"+\")\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "2ea81168",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'query': 'What are the Three (3) biggest countries, and their respective sizes?',\n",
" 'url': 'https://www.google.com/search?q=What+are+the+Three+(3)+biggest+countries,+and+their+respective+sizes?',\n",
" 'output': ' Russia (17,098,242 sq km), Canada (9,984,670 sq km), China (9,706,961 sq km)'}"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chain(inputs)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "db8f2b6d",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

@ -3,6 +3,7 @@ from langchain.chains.api.base import APIChain
from langchain.chains.conversation.base import ConversationChain
from langchain.chains.llm import LLMChain
from langchain.chains.llm_math.base import LLMMathChain
from langchain.chains.llm_requests import LLMRequestsChain
from langchain.chains.pal.base import PALChain
from langchain.chains.qa_with_sources.base import QAWithSourcesChain
from langchain.chains.qa_with_sources.vector_db import VectorDBQAWithSourcesChain
@ -22,4 +23,5 @@ __all__ = [
"VectorDBQAWithSourcesChain",
"PALChain",
"APIChain",
"LLMRequestsChain",
]

@ -3,7 +3,6 @@ from __future__ import annotations
from typing import Any, Dict, List, Optional
import requests
from pydantic import BaseModel, root_validator
from langchain.chains.api.prompt import API_RESPONSE_PROMPT, API_URL_PROMPT
@ -11,16 +10,7 @@ from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.input import print_text
from langchain.llms.base import LLM
class RequestsWrapper(BaseModel):
"""Lightweight wrapper to partial out everything except the url to hit."""
headers: Optional[dict] = None
def run(self, url: str) -> str:
"""Hit the URL and return the text."""
return requests.get(url, headers=self.headers).text
from langchain.requests import RequestsWrapper
class APIChain(Chain, BaseModel):

@ -0,0 +1,73 @@
"""Chain that hits a URL and then uses an LLM to parse results."""
from __future__ import annotations
from typing import Dict, List
from pydantic import BaseModel, Extra, Field, root_validator
from langchain.chains import LLMChain
from langchain.chains.base import Chain
from langchain.requests import RequestsWrapper
DEFAULT_HEADERS = {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/87.0.4280.88 Safari/537.36" # noqa: E501
}
class LLMRequestsChain(Chain, BaseModel):
"""Chain that hits a URL and then uses an LLM to parse results."""
llm_chain: LLMChain
requests_wrapper: RequestsWrapper = Field(default_factory=RequestsWrapper)
text_length: int = 8000
requests_key: str = "requests_result" #: :meta private:
input_key: str = "url" #: :meta private:
output_key: str = "output" #: :meta private:
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
@property
def input_keys(self) -> List[str]:
"""Will be whatever keys the prompt expects.
:meta private:
"""
return [self.input_key]
@property
def output_keys(self) -> List[str]:
"""Will always return text key.
:meta private:
"""
return [self.output_key]
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
try:
from bs4 import BeautifulSoup # noqa: F401
except ImportError:
raise ValueError(
"Could not import bs4 python package. "
"Please it install it with `pip install bs4`."
)
return values
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
from bs4 import BeautifulSoup
# Other keys are assumed to be needed for LLM prediction
other_keys = {k: v for k, v in inputs.items() if k != self.input_key}
url = inputs[self.input_key]
res = self.requests_wrapper.run(url)
# extract the text from the html
soup = BeautifulSoup(res, "html.parser")
other_keys[self.requests_key] = soup.get_text()[: self.text_length]
result = self.llm_chain.predict(**other_keys)
return {self.output_key: result}

@ -0,0 +1,15 @@
"""Lightweight wrapper around request library."""
from typing import Optional
import requests
from pydantic import BaseModel
class RequestsWrapper(BaseModel):
"""Lightweight wrapper to partial out everything except the url to hit."""
headers: Optional[dict] = None
def run(self, url: str) -> str:
"""Hit the URL and return the text."""
return requests.get(url, headers=self.headers).text

59
poetry.lock generated

@ -56,7 +56,7 @@ tests = ["pytest"]
[[package]]
name = "asttokens"
version = "2.2.0"
version = "2.2.1"
description = "Annotate AST trees with source code positions"
category = "dev"
optional = false
@ -355,15 +355,15 @@ devel = ["colorama", "json-spec", "jsonschema", "pylint", "pytest", "pytest-benc
[[package]]
name = "filelock"
version = "3.8.0"
version = "3.8.2"
description = "A platform independent file lock."
category = "main"
optional = true
python-versions = ">=3.7"
[package.extras]
docs = ["furo (>=2022.6.21)", "sphinx (>=5.1.1)", "sphinx-autodoc-typehints (>=1.19.1)"]
testing = ["covdefaults (>=2.2)", "coverage (>=6.4.2)", "pytest (>=7.1.2)", "pytest-cov (>=3)", "pytest-timeout (>=2.1)"]
docs = ["furo (>=2022.9.29)", "sphinx (>=5.3)", "sphinx-autodoc-typehints (>=1.19.5)"]
testing = ["covdefaults (>=2.2.2)", "coverage (>=6.5)", "pytest (>=7.2)", "pytest-cov (>=4)", "pytest-timeout (>=2.1)"]
[[package]]
name = "flake8"
@ -650,7 +650,7 @@ qtconsole = "*"
[[package]]
name = "jupyter-client"
version = "7.4.7"
version = "7.4.8"
description = "Jupyter protocol implementation and client libraries"
category = "dev"
optional = false
@ -903,7 +903,7 @@ test = ["ipykernel", "ipython", "ipywidgets", "nbconvert (>=7.0.0)", "pytest (>=
[[package]]
name = "nbconvert"
version = "7.2.5"
version = "7.2.6"
description = "Converting Jupyter Notebooks"
category = "dev"
optional = false
@ -928,12 +928,12 @@ tinycss2 = "*"
traitlets = ">=5.0"
[package.extras]
all = ["ipykernel", "ipython", "ipywidgets (>=7)", "myst-parser", "nbsphinx (>=0.2.12)", "pre-commit", "pyppeteer (>=1,<1.1)", "pyqtwebengine (>=5.15)", "pytest", "pytest-cov", "pytest-dependency", "sphinx (==5.0.2)", "sphinx-rtd-theme", "tornado (>=6.1)"]
docs = ["ipython", "myst-parser", "nbsphinx (>=0.2.12)", "sphinx (==5.0.2)", "sphinx-rtd-theme"]
qtpdf = ["pyqtwebengine (>=5.15)"]
all = ["nbconvert[docs,qtpdf,serve,test,webpdf]"]
docs = ["ipykernel", "ipython", "myst-parser", "nbsphinx (>=0.2.12)", "pydata-sphinx-theme", "sphinx (==5.0.2)"]
qtpdf = ["nbconvert[qtpng]"]
qtpng = ["pyqtwebengine (>=5.15)"]
serve = ["tornado (>=6.1)"]
test = ["ipykernel", "ipywidgets (>=7)", "pre-commit", "pyppeteer (>=1,<1.1)", "pytest", "pytest-cov", "pytest-dependency"]
test = ["ipykernel", "ipywidgets (>=7)", "pre-commit", "pyppeteer (>=1,<1.1)", "pytest", "pytest-dependency"]
webpdf = ["pyppeteer (>=1,<1.1)"]
[[package]]
@ -1452,15 +1452,14 @@ test = ["pytest (>=6,!=7.0.0,!=7.0.1)", "pytest-cov (>=3.0.0)", "pytest-qt"]
[[package]]
name = "redis"
version = "4.3.5"
version = "4.4.0"
description = "Python client for Redis database and key-value store"
category = "main"
optional = true
python-versions = ">=3.6"
python-versions = ">=3.7"
[package.dependencies]
async-timeout = ">=4.0.2"
packaging = ">=20.4"
[package.extras]
hiredis = ["hiredis (>=1.0.0)"]
@ -1711,7 +1710,7 @@ tests = ["cython", "littleutils", "pygments", "pytest", "typeguard"]
[[package]]
name = "terminado"
version = "0.17.0"
version = "0.17.1"
description = "Tornado websocket backend for the Xterm.js Javascript terminal emulator library."
category = "dev"
optional = false
@ -1723,7 +1722,7 @@ pywinpty = {version = ">=1.1.0", markers = "os_name == \"nt\""}
tornado = ">=6.1.0"
[package.extras]
docs = ["pydata-sphinx-theme", "sphinx"]
docs = ["myst-parser", "pydata-sphinx-theme", "sphinx"]
test = ["pre-commit", "pytest (>=7.0)", "pytest-timeout"]
[[package]]
@ -1946,7 +1945,7 @@ types-urllib3 = "<1.27"
name = "types-toml"
version = "0.10.8.1"
description = "Typing stubs for toml"
category = "main"
category = "dev"
optional = false
python-versions = "*"
@ -2049,13 +2048,13 @@ docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker
testing = ["flake8 (<5)", "func-timeout", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)"]
[extras]
all = ["manifest-ml", "elasticsearch", "faiss-cpu", "transformers", "spacy", "nltk", "wikipedia"]
all = ["manifest-ml", "elasticsearch", "faiss-cpu", "transformers", "spacy", "nltk", "wikipedia", "beautifulsoup4"]
llms = ["manifest-ml"]
[metadata]
lock-version = "1.1"
python-versions = ">=3.8.1,<4.0"
content-hash = "4c7e2033677ad4d5924096db6fed10f13be8f12e9098891e1f105e8fb5e72cee"
content-hash = "7f44c3b23d4fa30e192ec0f0a9218bcd646bb48dd64a813ee1bb7d61cbe3a5b2"
[metadata.files]
anyio = [
@ -2094,8 +2093,8 @@ argon2-cffi-bindings = [
{file = "argon2_cffi_bindings-21.2.0-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:5e00316dabdaea0b2dd82d141cc66889ced0cdcbfa599e8b471cf22c620c329a"},
]
asttokens = [
{file = "asttokens-2.2.0-py2.py3-none-any.whl", hash = "sha256:c56caef774a929923696f09ceea0eadcb95c94b30e8ee4f9fc4f5867096caaeb"},
{file = "asttokens-2.2.0.tar.gz", hash = "sha256:e27b1f115daebfafd4d1826fc75f9a72f0b74bd3ae4ee4d9380406d74d35e52c"},
{file = "asttokens-2.2.1-py2.py3-none-any.whl", hash = "sha256:6b0ac9e93fb0335014d382b8fa9b3afa7df546984258005da0b9e7095b3deb1c"},
{file = "asttokens-2.2.1.tar.gz", hash = "sha256:4622110b2a6f30b77e1473affaa97e711bc2f07d3f10848420ff1898edbe94f3"},
]
async-timeout = [
{file = "async-timeout-4.0.2.tar.gz", hash = "sha256:2163e1640ddb52b7a8c80d0a67a08587e5d245cc9c553a74a847056bc2976b15"},
@ -2374,8 +2373,8 @@ fastjsonschema = [
{file = "fastjsonschema-2.16.2.tar.gz", hash = "sha256:01e366f25d9047816fe3d288cbfc3e10541daf0af2044763f3d0ade42476da18"},
]
filelock = [
{file = "filelock-3.8.0-py3-none-any.whl", hash = "sha256:617eb4e5eedc82fc5f47b6d61e4d11cb837c56cb4544e39081099fa17ad109d4"},
{file = "filelock-3.8.0.tar.gz", hash = "sha256:55447caa666f2198c5b6b13a26d2084d26fa5b115c00d065664b2124680c4edc"},
{file = "filelock-3.8.2-py3-none-any.whl", hash = "sha256:8df285554452285f79c035efb0c861eb33a4bcfa5b7a137016e32e6a90f9792c"},
{file = "filelock-3.8.2.tar.gz", hash = "sha256:7565f628ea56bfcd8e54e42bdc55da899c85c1abfe1b5bcfd147e9188cebb3b2"},
]
flake8 = [
{file = "flake8-6.0.0-py2.py3-none-any.whl", hash = "sha256:3833794e27ff64ea4e9cf5d410082a8b97ff1a06c16aa3d2027339cd0f1195c7"},
@ -2509,8 +2508,8 @@ jupyter = [
{file = "jupyter-1.0.0.zip", hash = "sha256:3e1f86076bbb7c8c207829390305a2b1fe836d471ed54be66a3b8c41e7f46cc7"},
]
jupyter-client = [
{file = "jupyter_client-7.4.7-py3-none-any.whl", hash = "sha256:df56ae23b8e1da1b66f89dee1368e948b24a7f780fa822c5735187589fc4c157"},
{file = "jupyter_client-7.4.7.tar.gz", hash = "sha256:330f6b627e0b4bf2f54a3a0dd9e4a22d2b649c8518168afedce2c96a1ceb2860"},
{file = "jupyter_client-7.4.8-py3-none-any.whl", hash = "sha256:d4a67ae86ee014bcb96bd8190714f6af921f2b0f52f4208b086aa5acfd9f8d65"},
{file = "jupyter_client-7.4.8.tar.gz", hash = "sha256:109a3c33b62a9cf65aa8325850a0999a795fac155d9de4f7555aef5f310ee35a"},
]
jupyter-console = [
{file = "jupyter_console-6.4.4-py3-none-any.whl", hash = "sha256:756df7f4f60c986e7bc0172e4493d3830a7e6e75c08750bbe59c0a5403ad6dee"},
@ -2669,8 +2668,8 @@ nbclient = [
{file = "nbclient-0.7.2.tar.gz", hash = "sha256:884a3f4a8c4fc24bb9302f263e0af47d97f0d01fe11ba714171b320c8ac09547"},
]
nbconvert = [
{file = "nbconvert-7.2.5-py3-none-any.whl", hash = "sha256:3e90e108bb5637b5b8a1422af1156af1368b39dd25369ff7faa7dfdcdef18f81"},
{file = "nbconvert-7.2.5.tar.gz", hash = "sha256:8fdc44fd7d9424db7fdc6e1e834a02f6b8620ffb653767388be2f9eb16f84184"},
{file = "nbconvert-7.2.6-py3-none-any.whl", hash = "sha256:f933e82fe48b9a421e4252249f6c0a9a9940dc555642b4729f3f1f526bb16779"},
{file = "nbconvert-7.2.6.tar.gz", hash = "sha256:c9c0e4b26326f7658ebf4cda0acc591b9727c4e3ee3ede962f70c11833b71b40"},
]
nbformat = [
{file = "nbformat-5.7.0-py3-none-any.whl", hash = "sha256:1b05ec2c552c2f1adc745f4eddce1eac8ca9ffd59bb9fd859e827eaa031319f9"},
@ -3094,8 +3093,8 @@ qtpy = [
{file = "QtPy-2.3.0.tar.gz", hash = "sha256:0603c9c83ccc035a4717a12908bf6bc6cb22509827ea2ec0e94c2da7c9ed57c5"},
]
redis = [
{file = "redis-4.3.5-py3-none-any.whl", hash = "sha256:46652271dc7525cd5a9667e5b0ca983c848c75b2b8f7425403395bb8379dcf25"},
{file = "redis-4.3.5.tar.gz", hash = "sha256:30c07511627a4c5c4d970e060000772f323174f75e745a26938319817ead7a12"},
{file = "redis-4.4.0-py3-none-any.whl", hash = "sha256:cae3ee5d1f57d8caf534cd8764edf3163c77e073bdd74b6f54a87ffafdc5e7d9"},
{file = "redis-4.4.0.tar.gz", hash = "sha256:7b8c87d19c45d3f1271b124858d2a5c13160c4e74d4835e28273400fa34d5228"},
]
regex = [
{file = "regex-2022.10.31-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a8ff454ef0bb061e37df03557afda9d785c905dab15584860f982e88be73015f"},
@ -3338,8 +3337,8 @@ stack-data = [
{file = "stack_data-0.6.2.tar.gz", hash = "sha256:32d2dd0376772d01b6cb9fc996f3c8b57a357089dec328ed4b6553d037eaf815"},
]
terminado = [
{file = "terminado-0.17.0-py3-none-any.whl", hash = "sha256:bf6fe52accd06d0661d7611cc73202121ec6ee51e46d8185d489ac074ca457c2"},
{file = "terminado-0.17.0.tar.gz", hash = "sha256:520feaa3aeab8ad64a69ca779be54be9234edb2d0d6567e76c93c2c9a4e6e43f"},
{file = "terminado-0.17.1-py3-none-any.whl", hash = "sha256:8650d44334eba354dd591129ca3124a6ba42c3d5b70df5051b6921d506fdaeae"},
{file = "terminado-0.17.1.tar.gz", hash = "sha256:6ccbbcd3a4f8a25a5ec04991f39a0b8db52dfcd487ea0e578d977e6752380333"},
]
thinc = [
{file = "thinc-8.1.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:5dc6629e4770a13dec34eda3c4d89302f1b5c91ac4663cd53f876a4e761fcc00"},

@ -21,6 +21,7 @@ manifest-ml = {version = "^0.0.1", optional = true}
spacy = {version = "^3", optional = true}
nltk = {version = "^3", optional = true}
transformers = {version = "^4", optional = true}
beautifulsoup4 = {version = "^4", optional = true}
[tool.poetry.group.test.dependencies]
pytest = "^7.2.0"
@ -47,7 +48,7 @@ playwright = "^1.28.0"
[tool.poetry.extras]
llms = ["cohere", "openai", "nlpcloud", "huggingface_hub", "manifest-ml"]
all = ["cohere", "openai", "nlpcloud", "huggingface_hub", "manifest-ml", "elasticsearch", "google-search-results", "faiss-cpu", "sentence_transformers", "transformers", "spacy", "nltk", "wikipedia"]
all = ["cohere", "openai", "nlpcloud", "huggingface_hub", "manifest-ml", "elasticsearch", "google-search-results", "faiss-cpu", "sentence_transformers", "transformers", "spacy", "nltk", "wikipedia", "beautifulsoup4"]
[tool.isort]
profile = "black"

@ -5,8 +5,9 @@ import json
import pytest
from langchain import LLMChain
from langchain.chains.api.base import APIChain, RequestsWrapper
from langchain.chains.api.base import APIChain
from langchain.chains.api.prompt import API_RESPONSE_PROMPT, API_URL_PROMPT
from langchain.requests import RequestsWrapper
from tests.unit_tests.llms.fake_llm import FakeLLM

Loading…
Cancel
Save