diff --git a/libs/langchain/langchain/document_loaders/recursive_url_loader.py b/libs/langchain/langchain/document_loaders/recursive_url_loader.py index 2b93996a73..4781609ac0 100644 --- a/libs/langchain/langchain/document_loaders/recursive_url_loader.py +++ b/libs/langchain/langchain/document_loaders/recursive_url_loader.py @@ -145,7 +145,8 @@ class RecursiveUrlLoader(BaseLoader): # Store the visited links and recursively visit the children sub_links = extract_sub_links( response.text, - self.url, + url, + base_url=self.url, pattern=self.link_regex, prevent_outside=self.prevent_outside, ) @@ -224,7 +225,8 @@ class RecursiveUrlLoader(BaseLoader): if depth < self.max_depth - 1: sub_links = extract_sub_links( text, - self.url, + url, + base_url=self.url, pattern=self.link_regex, prevent_outside=self.prevent_outside, ) diff --git a/libs/langchain/langchain/utils/html.py b/libs/langchain/langchain/utils/html.py index ebdd7b86ba..d1f76cdabd 100644 --- a/libs/langchain/langchain/utils/html.py +++ b/libs/langchain/langchain/utils/html.py @@ -1,5 +1,5 @@ import re -from typing import List, Union +from typing import List, Optional, Union from urllib.parse import urljoin, urlparse PREFIXES_TO_IGNORE = ("javascript:", "mailto:", "#") @@ -37,16 +37,18 @@ def find_all_links( def extract_sub_links( raw_html: str, - base_url: str, + url: str, *, + base_url: Optional[str] = None, pattern: Union[str, re.Pattern, None] = None, prevent_outside: bool = True, ) -> List[str]: """Extract all links from a raw html string and convert into absolute paths. Args: - raw_html: original html - base_url: the base url of the html + raw_html: original html. + url: the url of the html. + base_url: the base url to check for outside links against. pattern: Regex to use for extracting links from raw html. prevent_outside: If True, ignore external links which are not children of the base url. 
@@ -54,6 +56,7 @@ def extract_sub_links( Returns: List[str]: sub links """ + base_url = base_url if base_url is not None else url all_links = find_all_links(raw_html, pattern=pattern) absolute_paths = set() for link in all_links: @@ -62,9 +65,9 @@ absolute_paths.add(link) # Some may have omitted the protocol like //to/path elif link.startswith("//"): - absolute_paths.add(f"{urlparse(base_url).scheme}:{link}") + absolute_paths.add(f"{urlparse(url).scheme}:{link}") else: - absolute_paths.add(urljoin(base_url, link)) + absolute_paths.add(urljoin(url, link)) if prevent_outside: return [p for p in absolute_paths if p.startswith(base_url)] return list(absolute_paths) diff --git a/libs/langchain/tests/unit_tests/utils/test_html.py b/libs/langchain/tests/unit_tests/utils/test_html.py index eaaa3544e8..b961f966d9 100644 --- a/libs/langchain/tests/unit_tests/utils/test_html.py +++ b/libs/langchain/tests/unit_tests/utils/test_html.py @@ -102,3 +102,28 @@ def test_extract_sub_links() -> None: ] ) assert actual == expected + + +def test_extract_sub_links_base() -> None: + html = ( + '<a href="https://foobar.com">one</a>' + '<a href="http://baz.net">two</a>' + '<a href="//foobar.com/hello">three</a>' + '<a href="/how/are/you/doing">four</a>' + '<a href="alexis.html"></a>' + ) + + expected = sorted( + [ + "https://foobar.com", + "https://foobar.com/hello", + "https://foobar.com/how/are/you/doing", + "https://foobar.com/hello/alexis.html", + ] + ) + actual = sorted( + extract_sub_links( + html, "https://foobar.com/hello/bill.html", base_url="https://foobar.com" + ) + ) + assert actual == expected