Sitemap specify default filter url (#11925)

Specify default filter URL in sitemap loader and add a security note

---------

Co-authored-by: Predrag Gruevski <2348618+obi1kenobi@users.noreply.github.com>
This commit is contained in:
Eugene Yurtsev 2023-10-17 13:19:27 -04:00 committed by GitHub
parent ba0d729961
commit 90e9ec6962
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 103 additions and 19 deletions

View File

@ -1,6 +1,7 @@
import itertools import itertools
import re import re
from typing import Any, Callable, Generator, Iterable, List, Optional from typing import Any, Callable, Generator, Iterable, List, Optional, Tuple
from urllib.parse import urlparse
from langchain.document_loaders.web_base import WebBaseLoader from langchain.document_loaders.web_base import WebBaseLoader
from langchain.schema import Document from langchain.schema import Document
@ -20,8 +21,47 @@ def _batch_block(iterable: Iterable, size: int) -> Generator[List[dict], None, N
yield item yield item
def _extract_scheme_and_domain(url: str) -> Tuple[str, str]:
"""Extract the scheme + domain from a given URL.
Args:
url (str): The input URL.
Returns:
return a 2-tuple of scheme and domain
"""
parsed_uri = urlparse(url)
return parsed_uri.scheme, parsed_uri.netloc
class SitemapLoader(WebBaseLoader): class SitemapLoader(WebBaseLoader):
"""Load a sitemap and its URLs.""" """Load a sitemap and its URLs.
**Security Note**: This loader can be used to load all URLs specified in a sitemap.
If a malicious actor gets access to the sitemap, they could force
the server to load URLs from other domains by modifying the sitemap.
This could lead to server-side request forgery (SSRF) attacks; e.g.,
with the attacker forcing the server to load URLs from internal
service endpoints that are not publicly accessible. While the attacker
may not immediately gain access to this data, this data could leak
into downstream systems (e.g., data loader is used to load data for indexing).
This loader is a crawler and web crawlers should generally NOT be deployed
with network access to any internal servers.
Control access to who can submit crawling requests and what network access
the crawler has.
By default, the loader will only load URLs from the same domain as the sitemap
if the site map is not a local file. This can be disabled by setting
restrict_to_same_domain to False (not recommended).
If the site map is a local file, no such risk mitigation is applied by default.
Use the filter URLs argument to limit which URLs can be loaded.
See https://python.langchain.com/docs/security
"""
def __init__( def __init__(
self, self,
@ -33,14 +73,22 @@ class SitemapLoader(WebBaseLoader):
meta_function: Optional[Callable] = None, meta_function: Optional[Callable] = None,
is_local: bool = False, is_local: bool = False,
continue_on_failure: bool = False, continue_on_failure: bool = False,
restrict_to_same_domain: bool = True,
**kwargs: Any, **kwargs: Any,
): ):
"""Initialize with webpage path and optional filter URLs. """Initialize with webpage path and optional filter URLs.
Args: Args:
web_path: url of the sitemap. can also be a local path web_path: url of the sitemap. can also be a local path
filter_urls: list of strings or regexes that will be applied to filter the filter_urls: a list of regexes. If specified, only
urls that are parsed and loaded URLS that match one of the filter URLs will be loaded.
*WARNING* The filter URLs are interpreted as regular expressions.
Remember to escape special characters if you do not want them to be
interpreted as regular expression syntax. For example, `.` appears
frequently in URLs and should be escaped if you want to match a literal
`.` rather than any character.
restrict_to_same_domain takes precedence over filter_urls when
restrict_to_same_domain is True and the sitemap is not a local file.
parsing_function: Function to parse bs4.Soup output parsing_function: Function to parse bs4.Soup output
blocksize: number of sitemap locations per block blocksize: number of sitemap locations per block
blocknum: the number of the block that should be loaded - zero indexed. blocknum: the number of the block that should be loaded - zero indexed.
@ -53,6 +101,9 @@ class SitemapLoader(WebBaseLoader):
occurs loading a url, emitting a warning instead of raising an occurs loading a url, emitting a warning instead of raising an
exception. Setting this to True makes the loader more robust, but also exception. Setting this to True makes the loader more robust, but also
may result in missing data. Default: False may result in missing data. Default: False
restrict_to_same_domain: whether to restrict loading to URLs to the same
domain as the sitemap. Attention: This is only applied if the sitemap
is not a local file!
""" """
if blocksize is not None and blocksize < 1: if blocksize is not None and blocksize < 1:
@ -65,12 +116,17 @@ class SitemapLoader(WebBaseLoader):
import lxml # noqa:F401 import lxml # noqa:F401
except ImportError: except ImportError:
raise ImportError( raise ImportError(
"lxml package not found, please install it with " "`pip install lxml`" "lxml package not found, please install it with `pip install lxml`"
) )
super().__init__(web_paths=[web_path], **kwargs) super().__init__(web_paths=[web_path], **kwargs)
self.filter_urls = filter_urls # Define a list of URL patterns (interpreted as regular expressions) that
# will be allowed to be loaded.
# restrict_to_same_domain takes precedence over filter_urls when
# restrict_to_same_domain is True and the sitemap is not a local file.
self.allow_url_patterns = filter_urls
self.restrict_to_same_domain = restrict_to_same_domain
self.parsing_function = parsing_function or _default_parsing_function self.parsing_function = parsing_function or _default_parsing_function
self.meta_function = meta_function or _default_meta_function self.meta_function = meta_function or _default_meta_function
self.blocksize = blocksize self.blocksize = blocksize
@ -96,8 +152,15 @@ class SitemapLoader(WebBaseLoader):
# Strip leading and trailing whitespace and newlines # Strip leading and trailing whitespace and newlines
loc_text = loc.text.strip() loc_text = loc.text.strip()
if self.filter_urls and not any( if self.restrict_to_same_domain and not self.is_local:
re.match(r, loc_text) for r in self.filter_urls if _extract_scheme_and_domain(loc_text) != _extract_scheme_and_domain(
self.web_path
):
continue
if self.allow_url_patterns and not any(
re.match(regexp_pattern, loc_text)
for regexp_pattern in self.allow_url_patterns
): ):
continue continue

View File

@ -4,11 +4,12 @@ from typing import Any
import pytest import pytest
from langchain.document_loaders import SitemapLoader from langchain.document_loaders import SitemapLoader
from langchain.document_loaders.sitemap import _extract_scheme_and_domain
def test_sitemap() -> None: def test_sitemap() -> None:
"""Test sitemap loader.""" """Test sitemap loader."""
loader = SitemapLoader("https://langchain.readthedocs.io/sitemap.xml") loader = SitemapLoader("https://api.python.langchain.com/sitemap.xml")
documents = loader.load() documents = loader.load()
assert len(documents) > 1 assert len(documents) > 1
assert "LangChain Python API" in documents[0].page_content assert "LangChain Python API" in documents[0].page_content
@ -17,7 +18,7 @@ def test_sitemap() -> None:
def test_sitemap_block() -> None: def test_sitemap_block() -> None:
"""Test sitemap loader.""" """Test sitemap loader."""
loader = SitemapLoader( loader = SitemapLoader(
"https://langchain.readthedocs.io/sitemap.xml", blocksize=1, blocknum=1 "https://api.python.langchain.com/sitemap.xml", blocksize=1, blocknum=1
) )
documents = loader.load() documents = loader.load()
assert len(documents) == 1 assert len(documents) == 1
@ -27,7 +28,7 @@ def test_sitemap_block() -> None:
def test_sitemap_block_only_one() -> None: def test_sitemap_block_only_one() -> None:
"""Test sitemap loader.""" """Test sitemap loader."""
loader = SitemapLoader( loader = SitemapLoader(
"https://langchain.readthedocs.io/sitemap.xml", blocksize=1000000, blocknum=0 "https://api.python.langchain.com/sitemap.xml", blocksize=1000000, blocknum=0
) )
documents = loader.load() documents = loader.load()
assert len(documents) > 1 assert len(documents) > 1
@ -37,7 +38,7 @@ def test_sitemap_block_only_one() -> None:
def test_sitemap_block_blocknum_default() -> None: def test_sitemap_block_blocknum_default() -> None:
"""Test sitemap loader.""" """Test sitemap loader."""
loader = SitemapLoader( loader = SitemapLoader(
"https://langchain.readthedocs.io/sitemap.xml", blocksize=1000000 "https://api.python.langchain.com/sitemap.xml", blocksize=1000000
) )
documents = loader.load() documents = loader.load()
assert len(documents) > 1 assert len(documents) > 1
@ -47,14 +48,14 @@ def test_sitemap_block_blocknum_default() -> None:
def test_sitemap_block_size_to_small() -> None: def test_sitemap_block_size_to_small() -> None:
"""Test sitemap loader.""" """Test sitemap loader."""
with pytest.raises(ValueError, match="Sitemap blocksize should be at least 1"): with pytest.raises(ValueError, match="Sitemap blocksize should be at least 1"):
SitemapLoader("https://langchain.readthedocs.io/sitemap.xml", blocksize=0) SitemapLoader("https://api.python.langchain.com/sitemap.xml", blocksize=0)
def test_sitemap_block_num_to_small() -> None: def test_sitemap_block_num_to_small() -> None:
"""Test sitemap loader.""" """Test sitemap loader."""
with pytest.raises(ValueError, match="Sitemap blocknum can not be lower then 0"): with pytest.raises(ValueError, match="Sitemap blocknum can not be lower then 0"):
SitemapLoader( SitemapLoader(
"https://langchain.readthedocs.io/sitemap.xml", "https://api.python.langchain.com/sitemap.xml",
blocksize=1000000, blocksize=1000000,
blocknum=-1, blocknum=-1,
) )
@ -63,7 +64,7 @@ def test_sitemap_block_num_to_small() -> None:
def test_sitemap_block_does_not_exists() -> None: def test_sitemap_block_does_not_exists() -> None:
"""Test sitemap loader.""" """Test sitemap loader."""
loader = SitemapLoader( loader = SitemapLoader(
"https://langchain.readthedocs.io/sitemap.xml", blocksize=1000000, blocknum=15 "https://api.python.langchain.com/sitemap.xml", blocksize=1000000, blocknum=15
) )
with pytest.raises( with pytest.raises(
ValueError, ValueError,
@ -75,7 +76,7 @@ def test_sitemap_block_does_not_exists() -> None:
def test_filter_sitemap() -> None: def test_filter_sitemap() -> None:
"""Test sitemap loader.""" """Test sitemap loader."""
loader = SitemapLoader( loader = SitemapLoader(
"https://langchain.readthedocs.io/sitemap.xml", "https://api.python.langchain.com/sitemap.xml",
filter_urls=["https://api.python.langchain.com/en/stable/"], filter_urls=["https://api.python.langchain.com/en/stable/"],
) )
documents = loader.load() documents = loader.load()
@ -89,7 +90,7 @@ def test_sitemap_metadata() -> None:
"""Test sitemap loader.""" """Test sitemap loader."""
loader = SitemapLoader( loader = SitemapLoader(
"https://langchain.readthedocs.io/sitemap.xml", "https://api.python.langchain.com/sitemap.xml",
meta_function=sitemap_metadata_one, meta_function=sitemap_metadata_one,
) )
documents = loader.load() documents = loader.load()
@ -107,7 +108,7 @@ def test_sitemap_metadata_extraction() -> None:
"""Test sitemap loader.""" """Test sitemap loader."""
loader = SitemapLoader( loader = SitemapLoader(
"https://langchain.readthedocs.io/sitemap.xml", "https://api.python.langchain.com/sitemap.xml",
meta_function=sitemap_metadata_two, meta_function=sitemap_metadata_two,
) )
documents = loader.load() documents = loader.load()
@ -118,7 +119,7 @@ def test_sitemap_metadata_extraction() -> None:
def test_sitemap_metadata_default() -> None: def test_sitemap_metadata_default() -> None:
"""Test sitemap loader.""" """Test sitemap loader."""
loader = SitemapLoader("https://langchain.readthedocs.io/sitemap.xml") loader = SitemapLoader("https://api.python.langchain.com/sitemap.xml")
documents = loader.load() documents = loader.load()
assert len(documents) > 1 assert len(documents) > 1
assert "source" in documents[0].metadata assert "source" in documents[0].metadata
@ -132,3 +133,23 @@ def test_local_sitemap() -> None:
documents = loader.load() documents = loader.load()
assert len(documents) > 1 assert len(documents) > 1
assert "🦜️🔗" in documents[0].page_content assert "🦜️🔗" in documents[0].page_content
def test_extract_domain() -> None:
"""Test domain extraction."""
assert _extract_scheme_and_domain("https://js.langchain.com/sitemap.xml") == (
"https",
"js.langchain.com",
)
assert _extract_scheme_and_domain("http://example.com/path/to/page") == (
"http",
"example.com",
)
assert _extract_scheme_and_domain("ftp://files.example.com") == (
"ftp",
"files.example.com",
)
assert _extract_scheme_and_domain("https://deep.subdomain.example.com") == (
"https",
"deep.subdomain.example.com",
)