fix recursive loader (#10752)

Maintain the same base URL throughout recursion, yield the initial page, and fix recursion depth tracking.
pull/10842/head
Bagatur committed 11 months ago via GitHub
parent 276125a33b
commit 96a9c27116
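For context, a minimal usage sketch reflecting the fixed behavior (an editor's illustration, not part of the diff; the URL and max_depth mirror the integration tests below):

from langchain.document_loaders.recursive_url_loader import RecursiveUrlLoader

# Illustrative only: with the corrected depth tracking, max_depth=2 now means
# "the start page plus pages one link away", and the start page itself is
# included in the results.
loader = RecursiveUrlLoader("https://docs.python.org/3.9/", max_depth=2)
docs = loader.load()
print(len(docs), docs[0].metadata["source"])

Sub-links are now always resolved against the original self.url at every level of the recursion, rather than against whichever page is currently being visited.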

@@ -1,12 +1,51 @@
from __future__ import annotations
import asyncio
import logging
import re
from typing import Callable, Iterator, List, Optional, Set, Union
from urllib.parse import urljoin, urlparse
from typing import (
TYPE_CHECKING,
Callable,
Iterator,
List,
Optional,
Sequence,
Set,
Union,
)
import requests
from langchain.docstore.document import Document
from langchain.document_loaders.base import BaseLoader
from langchain.utils.html import extract_sub_links
if TYPE_CHECKING:
import aiohttp
logger = logging.getLogger(__name__)
def _metadata_extractor(raw_html: str, url: str) -> dict:
"""Extract metadata from raw html using BeautifulSoup."""
metadata = {"source": url}
try:
from bs4 import BeautifulSoup
except ImportError:
logger.warning(
"The bs4 package is required for default metadata extraction. "
"Please install it with `pip install bs4`."
)
return metadata
soup = BeautifulSoup(raw_html, "html.parser")
if title := soup.find("title"):
metadata["title"] = title.get_text()
if description := soup.find("meta", attrs={"name": "description"}):
metadata["description"] = description.get("content", None)
if html := soup.find("html"):
metadata["language"] = html.get("lang", None)
return metadata
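As an aside (editor's illustration, not part of the commit), the default extractor above produces metadata like the following for a small page, assuming bs4 is installed:

# Hypothetical input, used only to show _metadata_extractor's output shape.
html = (
    '<html lang="en"><head><title>Example</title>'
    '<meta name="description" content="An example page."></head></html>'
)
_metadata_extractor(html, "https://example.com")
# -> {'source': 'https://example.com', 'title': 'Example',
#     'description': 'An example page.', 'language': 'en'}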
class RecursiveUrlLoader(BaseLoader):
@@ -15,173 +54,106 @@ class RecursiveUrlLoader(BaseLoader):
def __init__(
self,
url: str,
max_depth: Optional[int] = None,
max_depth: Optional[int] = 2,
use_async: Optional[bool] = None,
extractor: Optional[Callable[[str], str]] = None,
exclude_dirs: Optional[str] = None,
timeout: Optional[int] = None,
prevent_outside: Optional[bool] = None,
metadata_extractor: Optional[Callable[[str, str], str]] = None,
exclude_dirs: Optional[Sequence[str]] = (),
timeout: Optional[int] = 10,
prevent_outside: Optional[bool] = True,
link_regex: Union[str, re.Pattern, None] = None,
headers: Optional[dict] = None,
) -> None:
"""Initialize with URL to crawl and any subdirectories to exclude.
Args:
url: The URL to crawl.
exclude_dirs: A list of subdirectories to exclude.
use_async: Whether to use asynchronous loading,
if use_async is true, this function will not be lazy,
but it will still work in the expected way, just not lazy.
extractor: A function to extract the text from the html,
when extract function returns empty string, the document will be ignored.
max_depth: The max depth of the recursive loading.
timeout: The timeout for the requests, in the unit of seconds.
use_async: Whether to use asynchronous loading.
If True, this function will not be lazy, but it will still work in the
expected way, just not lazy.
extractor: A function to extract document contents from raw html.
When extract function returns an empty string, the document is
ignored.
metadata_extractor: A function to extract metadata from raw html and the
source url (args in that order). Default extractor will attempt
to use BeautifulSoup4 to extract the title, description and language
of the page.
exclude_dirs: A list of subdirectories to exclude.
timeout: The timeout for the requests, in the unit of seconds. If None then
connection will not timeout.
prevent_outside: If True, prevent loading from urls which are not children
of the root url.
link_regex: Regex for extracting sub-links from the raw html of a web page.
"""
self.url = url
self.exclude_dirs = exclude_dirs
self.max_depth = max_depth if max_depth is not None else 2
self.use_async = use_async if use_async is not None else False
self.extractor = extractor if extractor is not None else lambda x: x
self.max_depth = max_depth if max_depth is not None else 2
self.timeout = timeout if timeout is not None else 10
self.prevent_outside = prevent_outside if prevent_outside is not None else True
def _get_sub_links(self, raw_html: str, base_url: str) -> List[str]:
"""This function extracts all the links from the raw html,
and convert them into absolute paths.
Args:
raw_html (str): original html
base_url (str): the base url of the html
Returns:
List[str]: sub links
"""
# Get all links that are relative to the root of the website
all_links = re.findall(r"href=[\"\'](.*?)[\"\']", raw_html)
absolute_paths = []
invalid_prefixes = ("javascript:", "mailto:", "#")
invalid_suffixes = (
".css",
".js",
".ico",
".png",
".jpg",
".jpeg",
".gif",
".svg",
)
# Process the links
for link in all_links:
# Ignore blacklisted patterns
# like javascript: or mailto:, files of svg, ico, css, js
if link.startswith(invalid_prefixes) or link.endswith(invalid_suffixes):
continue
# Some may be absolute links like https://to/path
if link.startswith("http"):
if (not self.prevent_outside) or (
self.prevent_outside and link.startswith(base_url)
):
absolute_paths.append(link)
else:
absolute_paths.append(urljoin(base_url, link))
# Some may be relative links like /to/path
if link.startswith("/") and not link.startswith("//"):
absolute_paths.append(urljoin(base_url, link))
continue
# Some may have omitted the protocol like //to/path
if link.startswith("//"):
absolute_paths.append(f"{urlparse(base_url).scheme}:{link}")
continue
# Remove duplicates
# also do another filter to prevent outside links
absolute_paths = list(
set(
[
path
for path in absolute_paths
if not self.prevent_outside
or path.startswith(base_url)
and path != base_url
]
)
self.metadata_extractor = (
metadata_extractor
if metadata_extractor is not None
else _metadata_extractor
)
return absolute_paths
def _gen_metadata(self, raw_html: str, url: str) -> dict:
"""Build metadata from BeautifulSoup output."""
try:
from bs4 import BeautifulSoup
except ImportError:
print("The bs4 package is required for the RecursiveUrlLoader.")
print("Please install it with `pip install bs4`.")
metadata = {"source": url}
soup = BeautifulSoup(raw_html, "html.parser")
if title := soup.find("title"):
metadata["title"] = title.get_text()
if description := soup.find("meta", attrs={"name": "description"}):
metadata["description"] = description.get("content", None)
if html := soup.find("html"):
metadata["language"] = html.get("lang", None)
return metadata
self.exclude_dirs = exclude_dirs if exclude_dirs is not None else ()
self.timeout = timeout
self.prevent_outside = prevent_outside if prevent_outside is not None else True
self.link_regex = link_regex
self._lock = asyncio.Lock() if self.use_async else None
self.headers = headers
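To ground the parameter descriptions above, a configuration sketch (editor's illustration; the custom extractor, excluded directory, and headers are hypothetical):

from bs4 import BeautifulSoup
from langchain.document_loaders.recursive_url_loader import RecursiveUrlLoader

loader = RecursiveUrlLoader(
    "https://docs.python.org/3.9/",
    max_depth=2,
    # Strip tags instead of keeping raw HTML as page_content.
    extractor=lambda html: BeautifulSoup(html, "html.parser").get_text(),
    # exclude_dirs entries are compared with str.startswith against full URLs.
    exclude_dirs=("https://docs.python.org/3.9/whatsnew/",),
    timeout=10,
    prevent_outside=True,  # only follow links under the start URL
    headers={"User-Agent": "my-docs-crawler"},
)
docs = loader.load()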
def _get_child_links_recursive(
self, url: str, visited: Optional[Set[str]] = None, depth: int = 0
self, url: str, visited: Set[str], *, depth: int = 0
) -> Iterator[Document]:
"""Recursively get all child links starting with the path of the input URL.
Args:
url: The URL to crawl.
visited: A set of visited URLs.
depth: Current depth of recursion. Stop when depth >= max_depth.
"""
if depth > self.max_depth:
return []
# Add a trailing slash if not present
if not url.endswith("/"):
url += "/"
# Exclude the root and parent from a list
visited = set() if visited is None else visited
if depth >= self.max_depth:
return
# Exclude the links that start with any of the excluded directories
if self.exclude_dirs and any(
url.startswith(exclude_dir) for exclude_dir in self.exclude_dirs
):
return []
if any(url.startswith(exclude_dir) for exclude_dir in self.exclude_dirs):
return
# Get all links that can be accessed from the current URL
try:
response = requests.get(url, timeout=self.timeout)
response = requests.get(url, timeout=self.timeout, headers=self.headers)
except Exception:
return []
absolute_paths = self._get_sub_links(response.text, url)
logger.warning(f"Unable to load from {url}")
return
content = self.extractor(response.text)
if content:
yield Document(
page_content=content,
metadata=self.metadata_extractor(response.text, url),
)
visited.add(url)
# Store the visited links and recursively visit the children
for link in absolute_paths:
sub_links = extract_sub_links(
response.text,
self.url,
pattern=self.link_regex,
prevent_outside=self.prevent_outside,
)
for link in sub_links:
# Check all unvisited links
if link not in visited:
visited.add(link)
try:
response = requests.get(link)
text = response.text
except Exception:
# unreachable link, so just ignore it
continue
loaded_link = Document(
page_content=self.extractor(text),
metadata=self._gen_metadata(text, link),
yield from self._get_child_links_recursive(
link, visited, depth=depth + 1
)
yield loaded_link
# If the link is a directory (w/ children) then visit it
if link.endswith("/"):
yield from self._get_child_links_recursive(link, visited, depth + 1)
return []
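The synchronous traversal above boils down to the following self-contained pattern (editor's sketch over an in-memory toy site, not code from the commit): yield the current page, record it in the shared visited set, and recurse into unvisited links until depth >= max_depth.

from typing import Dict, Iterator, List, Set

# Toy "site": each URL maps to the URLs it links to.
SITE: Dict[str, List[str]] = {"/": ["/a", "/b"], "/a": ["/a/1"], "/b": [], "/a/1": []}

def crawl(url: str, visited: Set[str], *, depth: int = 0, max_depth: int = 2) -> Iterator[str]:
    if depth >= max_depth:  # same stopping rule as the loader
        return
    visited.add(url)
    yield url  # the current page (including the start page) is emitted
    for link in SITE[url]:
        if link not in visited:
            yield from crawl(link, visited, depth=depth + 1, max_depth=max_depth)

assert list(crawl("/", set())) == ["/", "/a", "/b"]  # max_depth=2: start page + direct children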
async def _async_get_child_links_recursive(
self, url: str, visited: Optional[Set[str]] = None, depth: int = 0
self,
url: str,
visited: Set[str],
*,
session: Optional[aiohttp.ClientSession] = None,
depth: int = 0,
) -> List[Document]:
"""Recursively get all child links starting with the path of the input URL.
@@ -193,117 +165,87 @@ class RecursiveUrlLoader(BaseLoader):
try:
import aiohttp
except ImportError:
print("The aiohttp package is required for the RecursiveUrlLoader.")
print("Please install it with `pip install aiohttp`.")
if depth > self.max_depth:
raise ImportError(
"The aiohttp package is required for the RecursiveUrlLoader. "
"Please install it with `pip install aiohttp`."
)
if depth >= self.max_depth:
return []
# Add a trailing slash if not present
if not url.endswith("/"):
url += "/"
# Exclude the root and parent from a list
visited = set() if visited is None else visited
# Exclude the links that start with any of the excluded directories
if self.exclude_dirs and any(
url.startswith(exclude_dir) for exclude_dir in self.exclude_dirs
):
if any(url.startswith(exclude_dir) for exclude_dir in self.exclude_dirs):
return []
# Disable SSL verification because websites may have invalid SSL certificates,
# but won't cause any security issues for us.
async with aiohttp.ClientSession(
close_session = session is None
session = session or aiohttp.ClientSession(
connector=aiohttp.TCPConnector(ssl=False),
timeout=aiohttp.ClientTimeout(self.timeout),
) as session:
# Some url may be invalid, so catch the exception
response: aiohttp.ClientResponse
try:
response = await session.get(url)
timeout=aiohttp.ClientTimeout(total=self.timeout),
headers=self.headers,
)
try:
async with session.get(url) as response:
text = await response.text()
except aiohttp.client_exceptions.InvalidURL:
return []
# There may be some other exceptions, so catch them,
# we don't want to stop the whole process
except Exception:
return []
absolute_paths = self._get_sub_links(text, url)
# Worker will be only called within the current function
# Worker function will process the link
# then recursively call get_child_links_recursive to process the children
async def worker(link: str) -> Union[Document, None]:
try:
async with aiohttp.ClientSession(
connector=aiohttp.TCPConnector(ssl=False),
timeout=aiohttp.ClientTimeout(self.timeout),
) as session:
response = await session.get(link)
text = await response.text()
extracted = self.extractor(text)
if len(extracted) > 0:
return Document(
page_content=extracted,
metadata=self._gen_metadata(text, link),
)
else:
return None
# Despite the fact that we have filtered some links,
# there may still be some invalid links, so catch the exception
except aiohttp.client_exceptions.InvalidURL:
return None
# There may be some other exceptions, so catch them,
# we don't want to stop the whole process
except Exception:
# print(e)
return None
# The coroutines that will be executed
tasks = []
# Generate the tasks
for link in absolute_paths:
# Check all unvisited links
if link not in visited:
visited.add(link)
tasks.append(worker(link))
# Get the not None results
results = list(
filter(lambda x: x is not None, await asyncio.gather(*tasks))
async with self._lock: # type: ignore
visited.add(url)
except (aiohttp.client_exceptions.InvalidURL, Exception) as e:
logger.warning(
f"Unable to load {url}. Received error {e} of type "
f"{e.__class__.__name__}"
)
return []
results = []
content = self.extractor(text)
if content:
results.append(
Document(
page_content=content,
metadata=self.metadata_extractor(text, url),
)
)
if depth < self.max_depth - 1:
sub_links = extract_sub_links(
text,
self.url,
pattern=self.link_regex,
prevent_outside=self.prevent_outside,
)
# Recursively call the function to get the children of the children
sub_tasks = []
for link in absolute_paths:
sub_tasks.append(
self._async_get_child_links_recursive(link, visited, depth + 1)
)
# sub_tasks returns coroutines of list,
# so we need to flatten the list await asyncio.gather(*sub_tasks)
flattened = []
async with self._lock: # type: ignore
to_visit = set(sub_links).difference(visited)
for link in to_visit:
sub_tasks.append(
self._async_get_child_links_recursive(
link, visited, session=session, depth=depth + 1
)
)
next_results = await asyncio.gather(*sub_tasks)
for sub_result in next_results:
if isinstance(sub_result, Exception):
if isinstance(sub_result, Exception) or sub_result is None:
# We don't want to stop the whole process, so just ignore it
# Not standard html format or invalid url or 404 may cause this
# But we can't do anything about it.
# Not standard html format or invalid url or 404 may cause this.
continue
if sub_result is not None:
flattened += sub_result
results += flattened
return list(filter(lambda x: x is not None, results))
# locking not fully working, temporary hack to ensure deduplication
results += [r for r in sub_result if r not in results]
if close_session:
await session.close()
return results
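The asynchronous variant above uses a fan-out pattern that can be sketched independently (editor's illustration, not code from the commit): fetch children concurrently with asyncio.gather while an asyncio.Lock guards the shared visited set, then merge and de-duplicate the partial results.

import asyncio
from typing import Dict, List, Set

SITE: Dict[str, List[str]] = {"/": ["/a", "/b"], "/a": [], "/b": ["/a"]}

async def crawl(
    url: str, visited: Set[str], lock: asyncio.Lock, *, depth: int = 0, max_depth: int = 2
) -> List[str]:
    if depth >= max_depth:
        return []
    async with lock:  # guard the shared visited set
        visited.add(url)
        to_visit = [link for link in SITE[url] if link not in visited]
    results = [url]  # the current page is part of the result
    sub_results = await asyncio.gather(
        *(crawl(link, visited, lock, depth=depth + 1, max_depth=max_depth) for link in to_visit)
    )
    for sub in sub_results:
        results += [r for r in sub if r not in results]  # de-duplicate merged results
    return results

async def main() -> List[str]:
    return await crawl("/", set(), asyncio.Lock())

print(asyncio.run(main()))  # e.g. ['/', '/a', '/b']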
def lazy_load(self) -> Iterator[Document]:
"""Lazy load web pages.
When use_async is True, this function will not be lazy,
but it will still work in the expected way, just not lazy."""
visited: Set[str] = set()
if self.use_async:
results = asyncio.run(self._async_get_child_links_recursive(self.url))
if results is None:
return iter([])
else:
return iter(results)
results = asyncio.run(
self._async_get_child_links_recursive(self.url, visited)
)
return iter(results or [])
else:
return self._get_child_links_recursive(self.url)
return self._get_child_links_recursive(self.url, visited)
def load(self) -> List[Document]:
"""Load web pages."""

@@ -0,0 +1,69 @@
import re
from typing import List, Union
from urllib.parse import urljoin, urlparse
PREFIXES_TO_IGNORE = ("javascript:", "mailto:", "#")
SUFFIXES_TO_IGNORE = (
".css",
".js",
".ico",
".png",
".jpg",
".jpeg",
".gif",
".svg",
".csv",
".bz2",
".zip",
".epub",
)
SUFFIXES_TO_IGNORE_REGEX = (
"(?!" + "|".join([re.escape(s) + "[\#'\"]" for s in SUFFIXES_TO_IGNORE]) + ")"
)
PREFIXES_TO_IGNORE_REGEX = (
"(?!" + "|".join([re.escape(s) for s in PREFIXES_TO_IGNORE]) + ")"
)
DEFAULT_LINK_REGEX = (
f"href=[\"']{PREFIXES_TO_IGNORE_REGEX}((?:{SUFFIXES_TO_IGNORE_REGEX}.)*?)[\#'\"]"
)
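To make the combined regex concrete, an illustrative check (editor's note, not part of the new module; re is imported at the top of the file and DEFAULT_LINK_REGEX is the constant defined just above). The behavior matches the unit tests further down:

assert re.findall(DEFAULT_LINK_REGEX, 'href="/baz/cool"') == ["/baz/cool"]
assert re.findall(DEFAULT_LINK_REGEX, 'href="style.css"') == []  # ignored suffix
assert re.findall(DEFAULT_LINK_REGEX, 'href="mailto:foo@bar.com"') == []  # ignored prefix
# Fragments are dropped because "#" terminates the captured group.
assert re.findall(DEFAULT_LINK_REGEX, 'href="foobar.com/woah#section_one"') == ["foobar.com/woah"]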
def find_all_links(
raw_html: str, *, pattern: Union[str, re.Pattern, None] = None
) -> List[str]:
pattern = pattern or DEFAULT_LINK_REGEX
return list(set(re.findall(pattern, raw_html)))
def extract_sub_links(
raw_html: str,
base_url: str,
*,
pattern: Union[str, re.Pattern, None] = None,
prevent_outside: bool = True,
) -> List[str]:
"""Extract all links from a raw html string and convert into absolute paths.
Args:
raw_html: original html
base_url: the base url of the html
pattern: Regex to use for extracting links from raw html.
prevent_outside: If True, ignore external links which are not children
of the base url.
Returns:
List[str]: sub links
"""
all_links = find_all_links(raw_html, pattern=pattern)
absolute_paths = set()
for link in all_links:
# Some may be absolute links like https://to/path
if link.startswith("http"):
if not prevent_outside or link.startswith(base_url):
absolute_paths.add(link)
# Some may have omitted the protocol like //to/path
elif link.startswith("//"):
absolute_paths.add(f"{urlparse(base_url).scheme}:{link}")
else:
absolute_paths.add(urljoin(base_url, link))
return list(absolute_paths)
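A brief illustration of the helper (editor's sketch; the expected output follows directly from the unit tests below): relative and protocol-relative links are resolved against the base URL, and with prevent_outside=True external links are dropped.

html = (
    '<a href="https://foobar.com">one</a>'
    '<a href="/how/are/you/doing">two</a>'
    '<a href="//foobar.com/hello">three</a>'
)
sorted(extract_sub_links(html, "https://foobar.com"))
# -> ['https://foobar.com', 'https://foobar.com/hello',
#     'https://foobar.com/how/are/you/doing']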

@@ -1,30 +1,61 @@
import pytest as pytest
from langchain.document_loaders.recursive_url_loader import RecursiveUrlLoader
@pytest.mark.asyncio
def test_async_recursive_url_loader() -> None:
url = "https://docs.python.org/3.9/"
loader = RecursiveUrlLoader(
url=url, extractor=lambda _: "placeholder", use_async=True, max_depth=1
url,
extractor=lambda _: "placeholder",
use_async=True,
max_depth=3,
timeout=None,
)
docs = loader.load()
assert len(docs) == 24
assert len(docs) == 1024
assert docs[0].page_content == "placeholder"
@pytest.mark.asyncio
def test_async_recursive_url_loader_deterministic() -> None:
url = "https://docs.python.org/3.9/"
loader = RecursiveUrlLoader(
url,
use_async=True,
max_depth=3,
timeout=None,
)
docs = sorted(loader.load(), key=lambda d: d.metadata["source"])
docs_2 = sorted(loader.load(), key=lambda d: d.metadata["source"])
assert docs == docs_2
def test_sync_recursive_url_loader() -> None:
url = "https://docs.python.org/3.9/"
loader = RecursiveUrlLoader(
url=url, extractor=lambda _: "placeholder", use_async=False, max_depth=1
url, extractor=lambda _: "placeholder", use_async=False, max_depth=2
)
docs = loader.load()
assert len(docs) == 24
assert len(docs) == 27
assert docs[0].page_content == "placeholder"
@pytest.mark.asyncio
def test_sync_async_equivalent() -> None:
url = "https://docs.python.org/3.9/"
loader = RecursiveUrlLoader(url, use_async=False, max_depth=2)
async_loader = RecursiveUrlLoader(url, use_async=True, max_depth=2)
docs = sorted(loader.load(), key=lambda d: d.metadata["source"])
async_docs = sorted(async_loader.load(), key=lambda d: d.metadata["source"])
assert docs == async_docs
def test_loading_invalid_url() -> None:
url = "https://this.url.is.invalid/this/is/a/test"
loader = RecursiveUrlLoader(
url=url, max_depth=1, extractor=lambda _: "placeholder", use_async=False
url, max_depth=1, extractor=lambda _: "placeholder", use_async=False
)
docs = loader.load()
assert len(docs) == 0

@@ -0,0 +1,109 @@
from langchain.utils.html import (
PREFIXES_TO_IGNORE,
SUFFIXES_TO_IGNORE,
extract_sub_links,
find_all_links,
)
def test_find_all_links_none() -> None:
html = "<span>Hello world</span>"
actual = find_all_links(html)
assert actual == []
def test_find_all_links_single() -> None:
htmls = [
"href='foobar.com'",
'href="foobar.com"',
'<div><a class="blah" href="foobar.com">hullo</a></div>',
]
actual = [find_all_links(html) for html in htmls]
assert actual == [["foobar.com"]] * 3
def test_find_all_links_multiple() -> None:
html = (
'<div><a class="blah" href="https://foobar.com">hullo</a></div>'
'<div><a class="bleh" href="/baz/cool">buhbye</a></div>'
)
actual = find_all_links(html)
assert sorted(actual) == [
"/baz/cool",
"https://foobar.com",
]
def test_find_all_links_ignore_suffix() -> None:
html = 'href="foobar{suffix}"'
for suffix in SUFFIXES_TO_IGNORE:
actual = find_all_links(html.format(suffix=suffix))
assert actual == []
# Don't ignore if pattern doesn't occur at end of link.
html = 'href="foobar{suffix}more"'
for suffix in SUFFIXES_TO_IGNORE:
actual = find_all_links(html.format(suffix=suffix))
assert actual == [f"foobar{suffix}more"]
def test_find_all_links_ignore_prefix() -> None:
html = 'href="{prefix}foobar"'
for prefix in PREFIXES_TO_IGNORE:
actual = find_all_links(html.format(prefix=prefix))
assert actual == []
# Don't ignore if pattern doesn't occur at beginning of link.
html = 'href="foobar{prefix}more"'
for prefix in PREFIXES_TO_IGNORE:
# Pound signs are split on when not prefixes.
if prefix == "#":
continue
actual = find_all_links(html.format(prefix=prefix))
assert actual == [f"foobar{prefix}more"]
def test_find_all_links_drop_fragment() -> None:
html = 'href="foobar.com/woah#section_one"'
actual = find_all_links(html)
assert actual == ["foobar.com/woah"]
def test_extract_sub_links() -> None:
html = (
'<a href="https://foobar.com">one</a>'
'<a href="http://baz.net">two</a>'
'<a href="//foobar.com/hello">three</a>'
'<a href="/how/are/you/doing">four</a>'
)
expected = sorted(
[
"https://foobar.com",
"https://foobar.com/hello",
"https://foobar.com/how/are/you/doing",
]
)
actual = sorted(extract_sub_links(html, "https://foobar.com"))
assert actual == expected
actual = sorted(extract_sub_links(html, "https://foobar.com/hello"))
expected = sorted(
[
"https://foobar.com/hello",
"https://foobar.com/how/are/you/doing",
]
)
assert actual == expected
actual = sorted(
extract_sub_links(html, "https://foobar.com/hello", prevent_outside=False)
)
expected = sorted(
[
"https://foobar.com",
"http://baz.net",
"https://foobar.com/hello",
"https://foobar.com/how/are/you/doing",
]
)
assert actual == expected