diff --git a/docs/examples/chains/llm_requests.ipynb b/docs/examples/chains/llm_requests.ipynb new file mode 100644 index 00000000..4d2ece40 --- /dev/null +++ b/docs/examples/chains/llm_requests.ipynb @@ -0,0 +1,123 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "dd7ec7af", + "metadata": {}, + "source": [ + "# LLMRequestsChain\n", + "\n", + "Using the requests library to get HTML results from a URL and then an LLM to parse results" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "dd8eae75", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.llms import OpenAI\n", + "from langchain.chains import LLMRequestsChain, LLMChain" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "65bf324e", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.prompts import PromptTemplate\n", + "\n", + "template = \"\"\"Between >>> and <<< are the raw search result text from google.\n", + "Extract the answer to the question '{query}' or say \"not found\" if the information is not contained.\n", + "Use the format\n", + "Extracted:<answer or \"not found\">\n", + ">>> {requests_result} <<<\n", + "Extracted:\"\"\"\n", + "\n", + "PROMPT = PromptTemplate(\n", + " input_variables=[\"query\", \"requests_result\"],\n", + " template=template,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "f36ae0d8", + "metadata": {}, + "outputs": [], + "source": [ + "chain = LLMRequestsChain(llm_chain = LLMChain(llm=OpenAI(temperature=0), prompt=PROMPT))" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "b5d22d9d", + "metadata": {}, + "outputs": [], + "source": [ + "question = \"What are the Three (3) biggest countries, and their respective sizes?\"\n", + "inputs = {\n", + " \"query\": question,\n", + " \"url\": \"https://www.google.com/search?q=\" + question.replace(\" \", \"+\")\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "2ea81168", + "metadata": {}, + "outputs": [ + { + 
"data": { + "text/plain": [ + "{'query': 'What are the Three (3) biggest countries, and their respective sizes?',\n", + " 'url': 'https://www.google.com/search?q=What+are+the+Three+(3)+biggest+countries,+and+their+respective+sizes?',\n", + " 'output': ' Russia (17,098,242 sq km), Canada (9,984,670 sq km), China (9,706,961 sq km)'}" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "chain(inputs)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "db8f2b6d", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.1" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/langchain/chains/__init__.py b/langchain/chains/__init__.py index 1152bc19..8b76937e 100644 --- a/langchain/chains/__init__.py +++ b/langchain/chains/__init__.py @@ -3,6 +3,7 @@ from langchain.chains.api.base import APIChain from langchain.chains.conversation.base import ConversationChain from langchain.chains.llm import LLMChain from langchain.chains.llm_math.base import LLMMathChain +from langchain.chains.llm_requests import LLMRequestsChain from langchain.chains.pal.base import PALChain from langchain.chains.qa_with_sources.base import QAWithSourcesChain from langchain.chains.qa_with_sources.vector_db import VectorDBQAWithSourcesChain @@ -22,4 +23,5 @@ __all__ = [ "VectorDBQAWithSourcesChain", "PALChain", "APIChain", + "LLMRequestsChain", ] diff --git a/langchain/chains/api/base.py b/langchain/chains/api/base.py index 807e8635..0eac3817 100644 --- a/langchain/chains/api/base.py +++ b/langchain/chains/api/base.py @@ -3,7 +3,6 @@ from 
__future__ import annotations from typing import Any, Dict, List, Optional -import requests from pydantic import BaseModel, root_validator from langchain.chains.api.prompt import API_RESPONSE_PROMPT, API_URL_PROMPT @@ -11,16 +10,7 @@ from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.input import print_text from langchain.llms.base import LLM - - -class RequestsWrapper(BaseModel): - """Lightweight wrapper to partial out everything except the url to hit.""" - - headers: Optional[dict] = None - - def run(self, url: str) -> str: - """Hit the URL and return the text.""" - return requests.get(url, headers=self.headers).text +from langchain.requests import RequestsWrapper class APIChain(Chain, BaseModel): diff --git a/langchain/chains/llm_requests.py b/langchain/chains/llm_requests.py new file mode 100644 index 00000000..ed0efa21 --- /dev/null +++ b/langchain/chains/llm_requests.py @@ -0,0 +1,73 @@ +"""Chain that hits a URL and then uses an LLM to parse results.""" +from __future__ import annotations + +from typing import Dict, List + +from pydantic import BaseModel, Extra, Field, root_validator + +from langchain.chains import LLMChain +from langchain.chains.base import Chain +from langchain.requests import RequestsWrapper + +DEFAULT_HEADERS = { + "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/87.0.4280.88 Safari/537.36" # noqa: E501 +} + + +class LLMRequestsChain(Chain, BaseModel): + """Chain that hits a URL and then uses an LLM to parse results.""" + + llm_chain: LLMChain + requests_wrapper: RequestsWrapper = Field(default_factory=RequestsWrapper) + text_length: int = 8000 + requests_key: str = "requests_result" #: :meta private: + input_key: str = "url" #: :meta private: + output_key: str = "output" #: :meta private: + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + arbitrary_types_allowed = True + + @property + def 
input_keys(self) -> List[str]: + """Will be whatever keys the prompt expects. + + :meta private: + """ + return [self.input_key] + + @property + def output_keys(self) -> List[str]: + """Will always return text key. + + :meta private: + """ + return [self.output_key] + + @root_validator() + def validate_environment(cls, values: Dict) -> Dict: + """Validate that python package exists in environment.""" + try: + from bs4 import BeautifulSoup  # noqa: F401 + + except ImportError: + raise ValueError( + "Could not import bs4 python package. " + "Please install it with `pip install bs4`." + ) + return values + + def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: + from bs4 import BeautifulSoup + + # Other keys are assumed to be needed for LLM prediction + other_keys = {k: v for k, v in inputs.items() if k != self.input_key} + url = inputs[self.input_key] + res = self.requests_wrapper.run(url) + # extract the text from the html + soup = BeautifulSoup(res, "html.parser") + other_keys[self.requests_key] = soup.get_text()[: self.text_length] + result = self.llm_chain.predict(**other_keys) + return {self.output_key: result} diff --git a/langchain/requests.py b/langchain/requests.py new file mode 100644 index 00000000..70bf5913 --- /dev/null +++ b/langchain/requests.py @@ -0,0 +1,15 @@ +"""Lightweight wrapper around request library.""" +from typing import Optional + +import requests +from pydantic import BaseModel + + +class RequestsWrapper(BaseModel): + """Lightweight wrapper to partial out everything except the url to hit.""" + + headers: Optional[dict] = None + + def run(self, url: str) -> str: + """Hit the URL and return the text.""" + return requests.get(url, headers=self.headers).text diff --git a/poetry.lock b/poetry.lock index f4327980..a6a4d326 100644 --- a/poetry.lock +++ b/poetry.lock @@ -56,7 +56,7 @@ tests = ["pytest"] [[package]] name = "asttokens" -version = "2.2.0" +version = "2.2.1" description = "Annotate AST trees with source code 
positions" category = "dev" optional = false @@ -355,15 +355,15 @@ devel = ["colorama", "json-spec", "jsonschema", "pylint", "pytest", "pytest-benc [[package]] name = "filelock" -version = "3.8.0" +version = "3.8.2" description = "A platform independent file lock." category = "main" optional = true python-versions = ">=3.7" [package.extras] -docs = ["furo (>=2022.6.21)", "sphinx (>=5.1.1)", "sphinx-autodoc-typehints (>=1.19.1)"] -testing = ["covdefaults (>=2.2)", "coverage (>=6.4.2)", "pytest (>=7.1.2)", "pytest-cov (>=3)", "pytest-timeout (>=2.1)"] +docs = ["furo (>=2022.9.29)", "sphinx (>=5.3)", "sphinx-autodoc-typehints (>=1.19.5)"] +testing = ["covdefaults (>=2.2.2)", "coverage (>=6.5)", "pytest (>=7.2)", "pytest-cov (>=4)", "pytest-timeout (>=2.1)"] [[package]] name = "flake8" @@ -650,7 +650,7 @@ qtconsole = "*" [[package]] name = "jupyter-client" -version = "7.4.7" +version = "7.4.8" description = "Jupyter protocol implementation and client libraries" category = "dev" optional = false @@ -903,7 +903,7 @@ test = ["ipykernel", "ipython", "ipywidgets", "nbconvert (>=7.0.0)", "pytest (>= [[package]] name = "nbconvert" -version = "7.2.5" +version = "7.2.6" description = "Converting Jupyter Notebooks" category = "dev" optional = false @@ -928,12 +928,12 @@ tinycss2 = "*" traitlets = ">=5.0" [package.extras] -all = ["ipykernel", "ipython", "ipywidgets (>=7)", "myst-parser", "nbsphinx (>=0.2.12)", "pre-commit", "pyppeteer (>=1,<1.1)", "pyqtwebengine (>=5.15)", "pytest", "pytest-cov", "pytest-dependency", "sphinx (==5.0.2)", "sphinx-rtd-theme", "tornado (>=6.1)"] -docs = ["ipython", "myst-parser", "nbsphinx (>=0.2.12)", "sphinx (==5.0.2)", "sphinx-rtd-theme"] -qtpdf = ["pyqtwebengine (>=5.15)"] +all = ["nbconvert[docs,qtpdf,serve,test,webpdf]"] +docs = ["ipykernel", "ipython", "myst-parser", "nbsphinx (>=0.2.12)", "pydata-sphinx-theme", "sphinx (==5.0.2)"] +qtpdf = ["nbconvert[qtpng]"] qtpng = ["pyqtwebengine (>=5.15)"] serve = ["tornado (>=6.1)"] -test = 
["ipykernel", "ipywidgets (>=7)", "pre-commit", "pyppeteer (>=1,<1.1)", "pytest", "pytest-cov", "pytest-dependency"] +test = ["ipykernel", "ipywidgets (>=7)", "pre-commit", "pyppeteer (>=1,<1.1)", "pytest", "pytest-dependency"] webpdf = ["pyppeteer (>=1,<1.1)"] [[package]] @@ -1452,15 +1452,14 @@ test = ["pytest (>=6,!=7.0.0,!=7.0.1)", "pytest-cov (>=3.0.0)", "pytest-qt"] [[package]] name = "redis" -version = "4.3.5" +version = "4.4.0" description = "Python client for Redis database and key-value store" category = "main" optional = true -python-versions = ">=3.6" +python-versions = ">=3.7" [package.dependencies] async-timeout = ">=4.0.2" -packaging = ">=20.4" [package.extras] hiredis = ["hiredis (>=1.0.0)"] @@ -1711,7 +1710,7 @@ tests = ["cython", "littleutils", "pygments", "pytest", "typeguard"] [[package]] name = "terminado" -version = "0.17.0" +version = "0.17.1" description = "Tornado websocket backend for the Xterm.js Javascript terminal emulator library." category = "dev" optional = false @@ -1723,7 +1722,7 @@ pywinpty = {version = ">=1.1.0", markers = "os_name == \"nt\""} tornado = ">=6.1.0" [package.extras] -docs = ["pydata-sphinx-theme", "sphinx"] +docs = ["myst-parser", "pydata-sphinx-theme", "sphinx"] test = ["pre-commit", "pytest (>=7.0)", "pytest-timeout"] [[package]] @@ -1946,7 +1945,7 @@ types-urllib3 = "<1.27" name = "types-toml" version = "0.10.8.1" description = "Typing stubs for toml" -category = "main" +category = "dev" optional = false python-versions = "*" @@ -2049,13 +2048,13 @@ docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker testing = ["flake8 (<5)", "func-timeout", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)"] [extras] -all = ["manifest-ml", "elasticsearch", "faiss-cpu", "transformers", "spacy", "nltk", "wikipedia"] +all = ["manifest-ml", 
"elasticsearch", "faiss-cpu", "transformers", "spacy", "nltk", "wikipedia", "beautifulsoup4"] llms = ["manifest-ml"] [metadata] lock-version = "1.1" python-versions = ">=3.8.1,<4.0" -content-hash = "4c7e2033677ad4d5924096db6fed10f13be8f12e9098891e1f105e8fb5e72cee" +content-hash = "7f44c3b23d4fa30e192ec0f0a9218bcd646bb48dd64a813ee1bb7d61cbe3a5b2" [metadata.files] anyio = [ @@ -2094,8 +2093,8 @@ argon2-cffi-bindings = [ {file = "argon2_cffi_bindings-21.2.0-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:5e00316dabdaea0b2dd82d141cc66889ced0cdcbfa599e8b471cf22c620c329a"}, ] asttokens = [ - {file = "asttokens-2.2.0-py2.py3-none-any.whl", hash = "sha256:c56caef774a929923696f09ceea0eadcb95c94b30e8ee4f9fc4f5867096caaeb"}, - {file = "asttokens-2.2.0.tar.gz", hash = "sha256:e27b1f115daebfafd4d1826fc75f9a72f0b74bd3ae4ee4d9380406d74d35e52c"}, + {file = "asttokens-2.2.1-py2.py3-none-any.whl", hash = "sha256:6b0ac9e93fb0335014d382b8fa9b3afa7df546984258005da0b9e7095b3deb1c"}, + {file = "asttokens-2.2.1.tar.gz", hash = "sha256:4622110b2a6f30b77e1473affaa97e711bc2f07d3f10848420ff1898edbe94f3"}, ] async-timeout = [ {file = "async-timeout-4.0.2.tar.gz", hash = "sha256:2163e1640ddb52b7a8c80d0a67a08587e5d245cc9c553a74a847056bc2976b15"}, @@ -2374,8 +2373,8 @@ fastjsonschema = [ {file = "fastjsonschema-2.16.2.tar.gz", hash = "sha256:01e366f25d9047816fe3d288cbfc3e10541daf0af2044763f3d0ade42476da18"}, ] filelock = [ - {file = "filelock-3.8.0-py3-none-any.whl", hash = "sha256:617eb4e5eedc82fc5f47b6d61e4d11cb837c56cb4544e39081099fa17ad109d4"}, - {file = "filelock-3.8.0.tar.gz", hash = "sha256:55447caa666f2198c5b6b13a26d2084d26fa5b115c00d065664b2124680c4edc"}, + {file = "filelock-3.8.2-py3-none-any.whl", hash = "sha256:8df285554452285f79c035efb0c861eb33a4bcfa5b7a137016e32e6a90f9792c"}, + {file = "filelock-3.8.2.tar.gz", hash = "sha256:7565f628ea56bfcd8e54e42bdc55da899c85c1abfe1b5bcfd147e9188cebb3b2"}, ] flake8 = [ {file = "flake8-6.0.0-py2.py3-none-any.whl", hash = 
"sha256:3833794e27ff64ea4e9cf5d410082a8b97ff1a06c16aa3d2027339cd0f1195c7"}, @@ -2509,8 +2508,8 @@ jupyter = [ {file = "jupyter-1.0.0.zip", hash = "sha256:3e1f86076bbb7c8c207829390305a2b1fe836d471ed54be66a3b8c41e7f46cc7"}, ] jupyter-client = [ - {file = "jupyter_client-7.4.7-py3-none-any.whl", hash = "sha256:df56ae23b8e1da1b66f89dee1368e948b24a7f780fa822c5735187589fc4c157"}, - {file = "jupyter_client-7.4.7.tar.gz", hash = "sha256:330f6b627e0b4bf2f54a3a0dd9e4a22d2b649c8518168afedce2c96a1ceb2860"}, + {file = "jupyter_client-7.4.8-py3-none-any.whl", hash = "sha256:d4a67ae86ee014bcb96bd8190714f6af921f2b0f52f4208b086aa5acfd9f8d65"}, + {file = "jupyter_client-7.4.8.tar.gz", hash = "sha256:109a3c33b62a9cf65aa8325850a0999a795fac155d9de4f7555aef5f310ee35a"}, ] jupyter-console = [ {file = "jupyter_console-6.4.4-py3-none-any.whl", hash = "sha256:756df7f4f60c986e7bc0172e4493d3830a7e6e75c08750bbe59c0a5403ad6dee"}, @@ -2669,8 +2668,8 @@ nbclient = [ {file = "nbclient-0.7.2.tar.gz", hash = "sha256:884a3f4a8c4fc24bb9302f263e0af47d97f0d01fe11ba714171b320c8ac09547"}, ] nbconvert = [ - {file = "nbconvert-7.2.5-py3-none-any.whl", hash = "sha256:3e90e108bb5637b5b8a1422af1156af1368b39dd25369ff7faa7dfdcdef18f81"}, - {file = "nbconvert-7.2.5.tar.gz", hash = "sha256:8fdc44fd7d9424db7fdc6e1e834a02f6b8620ffb653767388be2f9eb16f84184"}, + {file = "nbconvert-7.2.6-py3-none-any.whl", hash = "sha256:f933e82fe48b9a421e4252249f6c0a9a9940dc555642b4729f3f1f526bb16779"}, + {file = "nbconvert-7.2.6.tar.gz", hash = "sha256:c9c0e4b26326f7658ebf4cda0acc591b9727c4e3ee3ede962f70c11833b71b40"}, ] nbformat = [ {file = "nbformat-5.7.0-py3-none-any.whl", hash = "sha256:1b05ec2c552c2f1adc745f4eddce1eac8ca9ffd59bb9fd859e827eaa031319f9"}, @@ -3094,8 +3093,8 @@ qtpy = [ {file = "QtPy-2.3.0.tar.gz", hash = "sha256:0603c9c83ccc035a4717a12908bf6bc6cb22509827ea2ec0e94c2da7c9ed57c5"}, ] redis = [ - {file = "redis-4.3.5-py3-none-any.whl", hash = "sha256:46652271dc7525cd5a9667e5b0ca983c848c75b2b8f7425403395bb8379dcf25"}, - 
{file = "redis-4.3.5.tar.gz", hash = "sha256:30c07511627a4c5c4d970e060000772f323174f75e745a26938319817ead7a12"}, + {file = "redis-4.4.0-py3-none-any.whl", hash = "sha256:cae3ee5d1f57d8caf534cd8764edf3163c77e073bdd74b6f54a87ffafdc5e7d9"}, + {file = "redis-4.4.0.tar.gz", hash = "sha256:7b8c87d19c45d3f1271b124858d2a5c13160c4e74d4835e28273400fa34d5228"}, ] regex = [ {file = "regex-2022.10.31-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a8ff454ef0bb061e37df03557afda9d785c905dab15584860f982e88be73015f"}, @@ -3338,8 +3337,8 @@ stack-data = [ {file = "stack_data-0.6.2.tar.gz", hash = "sha256:32d2dd0376772d01b6cb9fc996f3c8b57a357089dec328ed4b6553d037eaf815"}, ] terminado = [ - {file = "terminado-0.17.0-py3-none-any.whl", hash = "sha256:bf6fe52accd06d0661d7611cc73202121ec6ee51e46d8185d489ac074ca457c2"}, - {file = "terminado-0.17.0.tar.gz", hash = "sha256:520feaa3aeab8ad64a69ca779be54be9234edb2d0d6567e76c93c2c9a4e6e43f"}, + {file = "terminado-0.17.1-py3-none-any.whl", hash = "sha256:8650d44334eba354dd591129ca3124a6ba42c3d5b70df5051b6921d506fdaeae"}, + {file = "terminado-0.17.1.tar.gz", hash = "sha256:6ccbbcd3a4f8a25a5ec04991f39a0b8db52dfcd487ea0e578d977e6752380333"}, ] thinc = [ {file = "thinc-8.1.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:5dc6629e4770a13dec34eda3c4d89302f1b5c91ac4663cd53f876a4e761fcc00"}, diff --git a/pyproject.toml b/pyproject.toml index e286111b..1e4c5a6a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,7 @@ manifest-ml = {version = "^0.0.1", optional = true} spacy = {version = "^3", optional = true} nltk = {version = "^3", optional = true} transformers = {version = "^4", optional = true} +beautifulsoup4 = {version = "^4", optional = true} [tool.poetry.group.test.dependencies] pytest = "^7.2.0" @@ -47,7 +48,7 @@ playwright = "^1.28.0" [tool.poetry.extras] llms = ["cohere", "openai", "nlpcloud", "huggingface_hub", "manifest-ml"] -all = ["cohere", "openai", "nlpcloud", "huggingface_hub", "manifest-ml", "elasticsearch", 
"google-search-results", "faiss-cpu", "sentence_transformers", "transformers", "spacy", "nltk", "wikipedia"] +all = ["cohere", "openai", "nlpcloud", "huggingface_hub", "manifest-ml", "elasticsearch", "google-search-results", "faiss-cpu", "sentence_transformers", "transformers", "spacy", "nltk", "wikipedia", "beautifulsoup4"] [tool.isort] profile = "black" diff --git a/tests/unit_tests/chains/test_api.py b/tests/unit_tests/chains/test_api.py index 9915ef22..7a643632 100644 --- a/tests/unit_tests/chains/test_api.py +++ b/tests/unit_tests/chains/test_api.py @@ -5,8 +5,9 @@ import json import pytest from langchain import LLMChain -from langchain.chains.api.base import APIChain, RequestsWrapper +from langchain.chains.api.base import APIChain from langchain.chains.api.prompt import API_RESPONSE_PROMPT, API_URL_PROMPT +from langchain.requests import RequestsWrapper from tests.unit_tests.llms.fake_llm import FakeLLM