From 96a9c271167c13173747ce4c98bb621b83c4bf83 Mon Sep 17 00:00:00 2001 From: Bagatur <22008038+baskaryan@users.noreply.github.com> Date: Wed, 20 Sep 2023 08:16:54 -0700 Subject: [PATCH] fix recursive loader (#10752) maintain same base url throughout recursion, yield initial page, fixing recursion depth tracking --- .../document_loaders/recursive_url_loader.py | 386 ++++++++---------- libs/langchain/langchain/utils/html.py | 69 ++++ .../test_recursive_url_loader.py | 41 +- .../tests/unit_tests/utils/test_html.py | 109 +++++ 4 files changed, 378 insertions(+), 227 deletions(-) create mode 100644 libs/langchain/langchain/utils/html.py create mode 100644 libs/langchain/tests/unit_tests/utils/test_html.py diff --git a/libs/langchain/langchain/document_loaders/recursive_url_loader.py b/libs/langchain/langchain/document_loaders/recursive_url_loader.py index 61b9c7032e..6a73e515a8 100644 --- a/libs/langchain/langchain/document_loaders/recursive_url_loader.py +++ b/libs/langchain/langchain/document_loaders/recursive_url_loader.py @@ -1,12 +1,51 @@ +from __future__ import annotations + import asyncio +import logging import re -from typing import Callable, Iterator, List, Optional, Set, Union -from urllib.parse import urljoin, urlparse +from typing import ( + TYPE_CHECKING, + Callable, + Iterator, + List, + Optional, + Sequence, + Set, + Union, +) import requests from langchain.docstore.document import Document from langchain.document_loaders.base import BaseLoader +from langchain.utils.html import extract_sub_links + +if TYPE_CHECKING: + import aiohttp + +logger = logging.getLogger(__name__) + + +def _metadata_extractor(raw_html: str, url: str) -> dict: + """Extract metadata from raw html using BeautifulSoup.""" + metadata = {"source": url} + + try: + from bs4 import BeautifulSoup + except ImportError: + logger.warning( + "The bs4 package is required for default metadata extraction. " + "Please install it with `pip install bs4`." + ) + return metadata + soup = BeautifulSoup(raw_html, "html.parser") + if title := soup.find("title"): + metadata["title"] = title.get_text() + if description := soup.find("meta", attrs={"name": "description"}): + metadata["description"] = description.get("content", None) + if html := soup.find("html"): + metadata["language"] = html.get("lang", None) + return metadata class RecursiveUrlLoader(BaseLoader): @@ -15,173 +54,106 @@ class RecursiveUrlLoader(BaseLoader): def __init__( self, url: str, - max_depth: Optional[int] = None, + max_depth: Optional[int] = 2, use_async: Optional[bool] = None, extractor: Optional[Callable[[str], str]] = None, - exclude_dirs: Optional[str] = None, - timeout: Optional[int] = None, - prevent_outside: Optional[bool] = None, + metadata_extractor: Optional[Callable[[str, str], str]] = None, + exclude_dirs: Optional[Sequence[str]] = (), + timeout: Optional[int] = 10, + prevent_outside: Optional[bool] = True, + link_regex: Union[str, re.Pattern, None] = None, + headers: Optional[dict] = None, ) -> None: """Initialize with URL to crawl and any subdirectories to exclude. Args: url: The URL to crawl. - exclude_dirs: A list of subdirectories to exclude. - use_async: Whether to use asynchronous loading, - if use_async is true, this function will not be lazy, - but it will still work in the expected way, just not lazy. - extractor: A function to extract the text from the html, - when extract function returns empty string, the document will be ignored. max_depth: The max depth of the recursive loading. 
- timeout: The timeout for the requests, in the unit of seconds. + use_async: Whether to use asynchronous loading. + If True, this function will not be lazy, but it will still work in the + expected way, just not lazy. + extractor: A function to extract document contents from raw html. + When extract function returns an empty string, the document is + ignored. + metadata_extractor: A function to extract metadata from raw html and the + source url (args in that order). Default extractor will attempt + to use BeautifulSoup4 to extract the title, description and language + of the page. + exclude_dirs: A list of subdirectories to exclude. + timeout: The timeout for the requests, in the unit of seconds. If None then + connection will not timeout. + prevent_outside: If True, prevent loading from urls which are not children + of the root url. + link_regex: Regex for extracting sub-links from the raw html of a web page. """ self.url = url - self.exclude_dirs = exclude_dirs + self.max_depth = max_depth if max_depth is not None else 2 self.use_async = use_async if use_async is not None else False self.extractor = extractor if extractor is not None else lambda x: x - self.max_depth = max_depth if max_depth is not None else 2 - self.timeout = timeout if timeout is not None else 10 - self.prevent_outside = prevent_outside if prevent_outside is not None else True - - def _get_sub_links(self, raw_html: str, base_url: str) -> List[str]: - """This function extracts all the links from the raw html, - and convert them into absolute paths. - - Args: - raw_html (str): original html - base_url (str): the base url of the html - - Returns: - List[str]: sub links - """ - # Get all links that are relative to the root of the website - all_links = re.findall(r"href=[\"\'](.*?)[\"\']", raw_html) - absolute_paths = [] - invalid_prefixes = ("javascript:", "mailto:", "#") - invalid_suffixes = ( - ".css", - ".js", - ".ico", - ".png", - ".jpg", - ".jpeg", - ".gif", - ".svg", - ) - # Process the links - for link in all_links: - # Ignore blacklisted patterns - # like javascript: or mailto:, files of svg, ico, css, js - if link.startswith(invalid_prefixes) or link.endswith(invalid_suffixes): - continue - # Some may be absolute links like https://to/path - if link.startswith("http"): - if (not self.prevent_outside) or ( - self.prevent_outside and link.startswith(base_url) - ): - absolute_paths.append(link) - else: - absolute_paths.append(urljoin(base_url, link)) - - # Some may be relative links like /to/path - if link.startswith("/") and not link.startswith("//"): - absolute_paths.append(urljoin(base_url, link)) - continue - # Some may have omitted the protocol like //to/path - if link.startswith("//"): - absolute_paths.append(f"{urlparse(base_url).scheme}:{link}") - continue - # Remove duplicates - # also do another filter to prevent outside links - absolute_paths = list( - set( - [ - path - for path in absolute_paths - if not self.prevent_outside - or path.startswith(base_url) - and path != base_url - ] - ) + self.metadata_extractor = ( + metadata_extractor + if metadata_extractor is not None + else _metadata_extractor ) - - return absolute_paths - - def _gen_metadata(self, raw_html: str, url: str) -> dict: - """Build metadata from BeautifulSoup output.""" - try: - from bs4 import BeautifulSoup - except ImportError: - print("The bs4 package is required for the RecursiveUrlLoader.") - print("Please install it with `pip install bs4`.") - metadata = {"source": url} - soup = BeautifulSoup(raw_html, "html.parser") - if title := 
soup.find("title"): - metadata["title"] = title.get_text() - if description := soup.find("meta", attrs={"name": "description"}): - metadata["description"] = description.get("content", None) - if html := soup.find("html"): - metadata["language"] = html.get("lang", None) - return metadata + self.exclude_dirs = exclude_dirs if exclude_dirs is not None else () + self.timeout = timeout + self.prevent_outside = prevent_outside if prevent_outside is not None else True + self.link_regex = link_regex + self._lock = asyncio.Lock() if self.use_async else None + self.headers = headers def _get_child_links_recursive( - self, url: str, visited: Optional[Set[str]] = None, depth: int = 0 + self, url: str, visited: Set[str], *, depth: int = 0 ) -> Iterator[Document]: """Recursively get all child links starting with the path of the input URL. Args: url: The URL to crawl. visited: A set of visited URLs. + depth: Current depth of recursion. Stop when depth >= max_depth. """ - if depth > self.max_depth: - return [] - - # Add a trailing slash if not present - if not url.endswith("/"): - url += "/" - - # Exclude the root and parent from a list - visited = set() if visited is None else visited - + if depth >= self.max_depth: + return # Exclude the links that start with any of the excluded directories - if self.exclude_dirs and any( - url.startswith(exclude_dir) for exclude_dir in self.exclude_dirs - ): - return [] + if any(url.startswith(exclude_dir) for exclude_dir in self.exclude_dirs): + return # Get all links that can be accessed from the current URL try: - response = requests.get(url, timeout=self.timeout) + response = requests.get(url, timeout=self.timeout, headers=self.headers) except Exception: - return [] - - absolute_paths = self._get_sub_links(response.text, url) + logger.warning(f"Unable to load from {url}") + return + content = self.extractor(response.text) + if content: + yield Document( + page_content=content, + metadata=self.metadata_extractor(response.text, url), + ) + visited.add(url) # Store the visited links and recursively visit the children - for link in absolute_paths: + sub_links = extract_sub_links( + response.text, + self.url, + pattern=self.link_regex, + prevent_outside=self.prevent_outside, + ) + for link in sub_links: # Check all unvisited links if link not in visited: - visited.add(link) - - try: - response = requests.get(link) - text = response.text - except Exception: - # unreachable link, so just ignore it - continue - loaded_link = Document( - page_content=self.extractor(text), - metadata=self._gen_metadata(text, link), + yield from self._get_child_links_recursive( + link, visited, depth=depth + 1 ) - yield loaded_link - # If the link is a directory (w/ children) then visit it - if link.endswith("/"): - yield from self._get_child_links_recursive(link, visited, depth + 1) - return [] async def _async_get_child_links_recursive( - self, url: str, visited: Optional[Set[str]] = None, depth: int = 0 + self, + url: str, + visited: Set[str], + *, + session: Optional[aiohttp.ClientSession] = None, + depth: int = 0, ) -> List[Document]: """Recursively get all child links starting with the path of the input URL. @@ -193,117 +165,87 @@ class RecursiveUrlLoader(BaseLoader): try: import aiohttp except ImportError: - print("The aiohttp package is required for the RecursiveUrlLoader.") - print("Please install it with `pip install aiohttp`.") - if depth > self.max_depth: + raise ImportError( + "The aiohttp package is required for the RecursiveUrlLoader. 
" + "Please install it with `pip install aiohttp`." + ) + if depth >= self.max_depth: return [] - # Add a trailing slash if not present - if not url.endswith("/"): - url += "/" - # Exclude the root and parent from a list - visited = set() if visited is None else visited - # Exclude the links that start with any of the excluded directories - if self.exclude_dirs and any( - url.startswith(exclude_dir) for exclude_dir in self.exclude_dirs - ): + if any(url.startswith(exclude_dir) for exclude_dir in self.exclude_dirs): return [] # Disable SSL verification because websites may have invalid SSL certificates, # but won't cause any security issues for us. - async with aiohttp.ClientSession( + close_session = session is None + session = session or aiohttp.ClientSession( connector=aiohttp.TCPConnector(ssl=False), - timeout=aiohttp.ClientTimeout(self.timeout), - ) as session: - # Some url may be invalid, so catch the exception - response: aiohttp.ClientResponse - try: - response = await session.get(url) + timeout=aiohttp.ClientTimeout(total=self.timeout), + headers=self.headers, + ) + try: + async with session.get(url) as response: text = await response.text() - except aiohttp.client_exceptions.InvalidURL: - return [] - # There may be some other exceptions, so catch them, - # we don't want to stop the whole process - except Exception: - return [] - - absolute_paths = self._get_sub_links(text, url) - - # Worker will be only called within the current function - # Worker function will process the link - # then recursively call get_child_links_recursive to process the children - async def worker(link: str) -> Union[Document, None]: - try: - async with aiohttp.ClientSession( - connector=aiohttp.TCPConnector(ssl=False), - timeout=aiohttp.ClientTimeout(self.timeout), - ) as session: - response = await session.get(link) - text = await response.text() - extracted = self.extractor(text) - if len(extracted) > 0: - return Document( - page_content=extracted, - metadata=self._gen_metadata(text, link), - ) - else: - return None - # Despite the fact that we have filtered some links, - # there may still be some invalid links, so catch the exception - except aiohttp.client_exceptions.InvalidURL: - return None - # There may be some other exceptions, so catch them, - # we don't want to stop the whole process - except Exception: - # print(e) - return None - - # The coroutines that will be executed - tasks = [] - # Generate the tasks - for link in absolute_paths: - # Check all unvisited links - if link not in visited: - visited.add(link) - tasks.append(worker(link)) - # Get the not None results - results = list( - filter(lambda x: x is not None, await asyncio.gather(*tasks)) + async with self._lock: # type: ignore + visited.add(url) + except (aiohttp.client_exceptions.InvalidURL, Exception) as e: + logger.warning( + f"Unable to load {url}. 
Received error {e} of type " + f"{e.__class__.__name__}" ) + return [] + results = [] + content = self.extractor(text) + if content: + results.append( + Document( + page_content=content, + metadata=self.metadata_extractor(text, url), + ) + ) + if depth < self.max_depth - 1: + sub_links = extract_sub_links( + text, + self.url, + pattern=self.link_regex, + prevent_outside=self.prevent_outside, + ) + # Recursively call the function to get the children of the children sub_tasks = [] - for link in absolute_paths: - sub_tasks.append( - self._async_get_child_links_recursive(link, visited, depth + 1) - ) - # sub_tasks returns coroutines of list, - # so we need to flatten the list await asyncio.gather(*sub_tasks) - flattened = [] + async with self._lock: # type: ignore + to_visit = set(sub_links).difference(visited) + for link in to_visit: + sub_tasks.append( + self._async_get_child_links_recursive( + link, visited, session=session, depth=depth + 1 + ) + ) next_results = await asyncio.gather(*sub_tasks) for sub_result in next_results: - if isinstance(sub_result, Exception): + if isinstance(sub_result, Exception) or sub_result is None: # We don't want to stop the whole process, so just ignore it - # Not standard html format or invalid url or 404 may cause this - # But we can't do anything about it. + # Not standard html format or invalid url or 404 may cause this. continue - if sub_result is not None: - flattened += sub_result - results += flattened - return list(filter(lambda x: x is not None, results)) + # locking not fully working, temporary hack to ensure deduplication + results += [r for r in sub_result if r not in results] + if close_session: + await session.close() + return results def lazy_load(self) -> Iterator[Document]: """Lazy load web pages. When use_async is True, this function will not be lazy, but it will still work in the expected way, just not lazy.""" + visited: Set[str] = set() if self.use_async: - results = asyncio.run(self._async_get_child_links_recursive(self.url)) - if results is None: - return iter([]) - else: - return iter(results) + results = asyncio.run( + self._async_get_child_links_recursive(self.url, visited) + ) + return iter(results or []) else: - return self._get_child_links_recursive(self.url) + return self._get_child_links_recursive(self.url, visited) def load(self) -> List[Document]: """Load web pages.""" diff --git a/libs/langchain/langchain/utils/html.py b/libs/langchain/langchain/utils/html.py new file mode 100644 index 0000000000..8839b4a943 --- /dev/null +++ b/libs/langchain/langchain/utils/html.py @@ -0,0 +1,69 @@ +import re +from typing import List, Union +from urllib.parse import urljoin, urlparse + +PREFIXES_TO_IGNORE = ("javascript:", "mailto:", "#") +SUFFIXES_TO_IGNORE = ( + ".css", + ".js", + ".ico", + ".png", + ".jpg", + ".jpeg", + ".gif", + ".svg", + ".csv", + ".bz2", + ".zip", + ".epub", +) +SUFFIXES_TO_IGNORE_REGEX = ( + "(?!" + "|".join([re.escape(s) + "[\#'\"]" for s in SUFFIXES_TO_IGNORE]) + ")" +) +PREFIXES_TO_IGNORE_REGEX = ( + "(?!" 
+ "|".join([re.escape(s) for s in PREFIXES_TO_IGNORE]) + ")" +) +DEFAULT_LINK_REGEX = ( + f"href=[\"']{PREFIXES_TO_IGNORE_REGEX}((?:{SUFFIXES_TO_IGNORE_REGEX}.)*?)[\#'\"]" +) + + +def find_all_links( + raw_html: str, *, pattern: Union[str, re.Pattern, None] = None +) -> List[str]: + pattern = pattern or DEFAULT_LINK_REGEX + return list(set(re.findall(pattern, raw_html))) + + +def extract_sub_links( + raw_html: str, + base_url: str, + *, + pattern: Union[str, re.Pattern, None] = None, + prevent_outside: bool = True, +) -> List[str]: + """Extract all links from a raw html string and convert into absolute paths. + + Args: + raw_html: original html + base_url: the base url of the html + pattern: Regex to use for extracting links from raw html. + prevent_outside: If True, ignore external links which are not children + of the base url. + + Returns: + List[str]: sub links + """ + all_links = find_all_links(raw_html, pattern=pattern) + absolute_paths = set() + for link in all_links: + # Some may be absolute links like https://to/path + if link.startswith("http"): + if not prevent_outside or link.startswith(base_url): + absolute_paths.add(link) + # Some may have omitted the protocol like //to/path + elif link.startswith("//"): + absolute_paths.add(f"{urlparse(base_url).scheme}:{link}") + else: + absolute_paths.add(urljoin(base_url, link)) + return list(absolute_paths) diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_recursive_url_loader.py b/libs/langchain/tests/integration_tests/document_loaders/test_recursive_url_loader.py index c31bd369d7..b2080921dd 100644 --- a/libs/langchain/tests/integration_tests/document_loaders/test_recursive_url_loader.py +++ b/libs/langchain/tests/integration_tests/document_loaders/test_recursive_url_loader.py @@ -1,30 +1,61 @@ +import pytest as pytest + from langchain.document_loaders.recursive_url_loader import RecursiveUrlLoader +@pytest.mark.asyncio def test_async_recursive_url_loader() -> None: url = "https://docs.python.org/3.9/" loader = RecursiveUrlLoader( - url=url, extractor=lambda _: "placeholder", use_async=True, max_depth=1 + url, + extractor=lambda _: "placeholder", + use_async=True, + max_depth=3, + timeout=None, ) docs = loader.load() - assert len(docs) == 24 + assert len(docs) == 1024 assert docs[0].page_content == "placeholder" +@pytest.mark.asyncio +def test_async_recursive_url_loader_deterministic() -> None: + url = "https://docs.python.org/3.9/" + loader = RecursiveUrlLoader( + url, + use_async=True, + max_depth=3, + timeout=None, + ) + docs = sorted(loader.load(), key=lambda d: d.metadata["source"]) + docs_2 = sorted(loader.load(), key=lambda d: d.metadata["source"]) + assert docs == docs_2 + + def test_sync_recursive_url_loader() -> None: url = "https://docs.python.org/3.9/" loader = RecursiveUrlLoader( - url=url, extractor=lambda _: "placeholder", use_async=False, max_depth=1 + url, extractor=lambda _: "placeholder", use_async=False, max_depth=2 ) docs = loader.load() - assert len(docs) == 24 + assert len(docs) == 27 assert docs[0].page_content == "placeholder" +@pytest.mark.asyncio +def test_sync_async_equivalent() -> None: + url = "https://docs.python.org/3.9/" + loader = RecursiveUrlLoader(url, use_async=False, max_depth=2) + async_loader = RecursiveUrlLoader(url, use_async=False, max_depth=2) + docs = sorted(loader.load(), key=lambda d: d.metadata["source"]) + async_docs = sorted(async_loader.load(), key=lambda d: d.metadata["source"]) + assert docs == async_docs + + def test_loading_invalid_url() -> None: url = 
"https://this.url.is.invalid/this/is/a/test" loader = RecursiveUrlLoader( - url=url, max_depth=1, extractor=lambda _: "placeholder", use_async=False + url, max_depth=1, extractor=lambda _: "placeholder", use_async=False ) docs = loader.load() assert len(docs) == 0 diff --git a/libs/langchain/tests/unit_tests/utils/test_html.py b/libs/langchain/tests/unit_tests/utils/test_html.py new file mode 100644 index 0000000000..a5c42b6a34 --- /dev/null +++ b/libs/langchain/tests/unit_tests/utils/test_html.py @@ -0,0 +1,109 @@ +from langchain.utils.html import ( + PREFIXES_TO_IGNORE, + SUFFIXES_TO_IGNORE, + extract_sub_links, + find_all_links, +) + + +def test_find_all_links_none() -> None: + html = "Hello world" + actual = find_all_links(html) + assert actual == [] + + +def test_find_all_links_single() -> None: + htmls = [ + "href='foobar.com'", + 'href="foobar.com"', + '
hullo
', + ] + actual = [find_all_links(html) for html in htmls] + assert actual == [["foobar.com"]] * 3 + + +def test_find_all_links_multiple() -> None: + html = ( + '
hullo
' + '
buhbye
' + ) + actual = find_all_links(html) + assert sorted(actual) == [ + "/baz/cool", + "https://foobar.com", + ] + + +def test_find_all_links_ignore_suffix() -> None: + html = 'href="foobar{suffix}"' + for suffix in SUFFIXES_TO_IGNORE: + actual = find_all_links(html.format(suffix=suffix)) + assert actual == [] + + # Don't ignore if pattern doesn't occur at end of link. + html = 'href="foobar{suffix}more"' + for suffix in SUFFIXES_TO_IGNORE: + actual = find_all_links(html.format(suffix=suffix)) + assert actual == [f"foobar{suffix}more"] + + +def test_find_all_links_ignore_prefix() -> None: + html = 'href="{prefix}foobar"' + for prefix in PREFIXES_TO_IGNORE: + actual = find_all_links(html.format(prefix=prefix)) + assert actual == [] + + # Don't ignore if pattern doesn't occur at beginning of link. + html = 'href="foobar{prefix}more"' + for prefix in PREFIXES_TO_IGNORE: + # Pound signs are split on when not prefixes. + if prefix == "#": + continue + actual = find_all_links(html.format(prefix=prefix)) + assert actual == [f"foobar{prefix}more"] + + +def test_find_all_links_drop_fragment() -> None: + html = 'href="foobar.com/woah#section_one"' + actual = find_all_links(html) + assert actual == ["foobar.com/woah"] + + +def test_extract_sub_links() -> None: + html = ( + 'one' + 'two' + 'three' + 'four' + ) + expected = sorted( + [ + "https://foobar.com", + "https://foobar.com/hello", + "https://foobar.com/how/are/you/doing", + ] + ) + actual = sorted(extract_sub_links(html, "https://foobar.com")) + assert actual == expected + + actual = sorted(extract_sub_links(html, "https://foobar.com/hello")) + expected = sorted( + [ + "https://foobar.com/hello", + "https://foobar.com/how/are/you/doing", + ] + ) + assert actual == expected + + actual = sorted( + extract_sub_links(html, "https://foobar.com/hello", prevent_outside=False) + ) + expected = sorted( + [ + "https://foobar.com", + "http://baz.net", + "https://foobar.com/hello", + "https://foobar.com/how/are/you/doing", + ] + ) + assert actual == expected
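
Usage note for reviewers: after this patch the loader yields the initial page itself, resolves sub-links against the original base url throughout the recursion, and stops once depth >= max_depth (so max_depth=2 loads the root page plus its direct children). A minimal sketch of the changed constructor follows; the docs.python.org URL and the BeautifulSoup-based extractor are illustrative only and not part of this change.

    from bs4 import BeautifulSoup  # optional dependency, used here only for the example extractor

    from langchain.document_loaders.recursive_url_loader import RecursiveUrlLoader

    loader = RecursiveUrlLoader(
        "https://docs.python.org/3.9/",  # illustrative root url
        max_depth=2,                     # root page plus its direct children
        timeout=10,
        prevent_outside=True,            # stay under the root url (the default)
        extractor=lambda raw_html: BeautifulSoup(raw_html, "html.parser").get_text(),
    )
    docs = loader.load()
    # The default metadata extractor records source, title, description and language.
    print(docs[0].metadata["source"], docs[0].metadata.get("title"))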
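
The new metadata_extractor hook receives the raw html and the source url, in that order, and returns the document metadata. A hypothetical replacement for the default BeautifulSoup-based extractor, assuming a simple regex title grab is acceptable:

    import re

    from langchain.document_loaders.recursive_url_loader import RecursiveUrlLoader

    def title_only_metadata(raw_html: str, url: str) -> dict:
        # Hypothetical extractor: record the source url and the <title>, if any.
        match = re.search(r"<title>(.*?)</title>", raw_html, re.IGNORECASE | re.DOTALL)
        return {"source": url, "title": match.group(1).strip() if match else None}

    loader = RecursiveUrlLoader(
        "https://docs.python.org/3.9/",
        max_depth=2,
        metadata_extractor=title_only_metadata,
    )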
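
And a short sketch of how the new langchain.utils.html helpers behave; the html snippet is made up, and the behavior mirrors the unit tests above.

    from langchain.utils.html import extract_sub_links, find_all_links

    html = (
        '<a href="https://foobar.com/hello">in scope</a>'
        '<a href="http://baz.net">outside</a>'
        '<a href="/relative/path">relative</a>'
    )

    # Raw href values, skipping ignored prefixes (javascript:, mailto:, #) and
    # ignored file suffixes (.css, .js, images, archives, ...).
    print(sorted(find_all_links(html)))

    # Resolved against base_url; with the default prevent_outside=True, absolute
    # links that do not start with base_url (http://baz.net here) are dropped.
    print(sorted(extract_sub_links(html, "https://foobar.com")))
    print(sorted(extract_sub_links(html, "https://foobar.com", prevent_outside=False)))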