From 48ea27ba607bd1f00e5efc0337f3d54d56742377 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Mon, 1 May 2023 21:34:07 -0700 Subject: [PATCH] Harrison/blockwise sitemap (#3940) Co-authored-by: Martin Holzhauer --- langchain/document_loaders/sitemap.py | 31 +++++++++- .../document_loaders/test_sitemap.py | 62 ++++++++++++++++++- 2 files changed, 91 insertions(+), 2 deletions(-) diff --git a/langchain/document_loaders/sitemap.py b/langchain/document_loaders/sitemap.py index 3a417dd0..2b184f38 100644 --- a/langchain/document_loaders/sitemap.py +++ b/langchain/document_loaders/sitemap.py @@ -1,6 +1,7 @@ """Loader that fetches a sitemap and loads those URLs.""" +import itertools import re -from typing import Any, Callable, List, Optional +from typing import Any, Callable, Generator, Iterable, List, Optional from langchain.document_loaders.web_base import WebBaseLoader from langchain.schema import Document @@ -10,6 +11,12 @@ def _default_parsing_function(content: Any) -> str: return str(content.get_text()) +def _batch_block(iterable: Iterable, size: int) -> Generator[List[dict], None, None]: + it = iter(iterable) + while item := list(itertools.islice(it, size)): + yield item + + class SitemapLoader(WebBaseLoader): """Loader that fetches a sitemap and loads those URLs.""" @@ -18,6 +25,8 @@ class SitemapLoader(WebBaseLoader): web_path: str, filter_urls: Optional[List[str]] = None, parsing_function: Optional[Callable] = None, + blocksize: Optional[int] = None, + blocknum: int = 0, ): """Initialize with webpage path and optional filter URLs. @@ -26,8 +35,16 @@ class SitemapLoader(WebBaseLoader): filter_urls: list of strings or regexes that will be applied to filter the urls that are parsed and loaded parsing_function: Function to parse bs4.Soup output + blocksize: number of sitemap locations per block + blocknum: the number of the block that should be loaded - zero indexed """ + if blocksize is not None and blocksize < 1: + raise ValueError("Sitemap blocksize should be at least 1") + + if blocknum < 0: + raise ValueError("Sitemap blocknum can not be lower then 0") + try: import lxml # noqa:F401 except ImportError: @@ -39,6 +56,8 @@ class SitemapLoader(WebBaseLoader): self.filter_urls = filter_urls self.parsing_function = parsing_function or _default_parsing_function + self.blocksize = blocksize + self.blocknum = blocknum def parse_sitemap(self, soup: Any) -> List[dict]: """Parse sitemap xml and load into a list of dicts.""" @@ -76,6 +95,16 @@ class SitemapLoader(WebBaseLoader): els = self.parse_sitemap(soup) + if self.blocksize is not None: + elblocks = list(_batch_block(els, self.blocksize)) + blockcount = len(elblocks) + if blockcount - 1 < self.blocknum: + raise ValueError( + "Selected sitemap does not contain enough blocks for given blocknum" + ) + else: + els = elblocks[self.blocknum] + results = self.scrape_all([el["loc"].strip() for el in els if "loc" in el]) return [ diff --git a/tests/integration_tests/document_loaders/test_sitemap.py b/tests/integration_tests/document_loaders/test_sitemap.py index 87147ec6..3ac2a59e 100644 --- a/tests/integration_tests/document_loaders/test_sitemap.py +++ b/tests/integration_tests/document_loaders/test_sitemap.py @@ -1,3 +1,5 @@ +import pytest + from langchain.document_loaders import SitemapLoader @@ -9,11 +11,69 @@ def test_sitemap() -> None: assert "🦜🔗" in documents[0].page_content +def test_sitemap_block() -> None: + """Test sitemap loader.""" + loader = SitemapLoader( + "https://langchain.readthedocs.io/sitemap.xml", blocksize=1, blocknum=1 + ) + documents = loader.load() + assert len(documents) == 1 + assert "🦜🔗" in documents[0].page_content + + +def test_sitemap_block_only_one() -> None: + """Test sitemap loader.""" + loader = SitemapLoader( + "https://langchain.readthedocs.io/sitemap.xml", blocksize=1000000, blocknum=0 + ) + documents = loader.load() + assert len(documents) > 1 + assert "🦜🔗" in documents[0].page_content + + +def test_sitemap_block_blocknum_default() -> None: + """Test sitemap loader.""" + loader = SitemapLoader( + "https://langchain.readthedocs.io/sitemap.xml", blocksize=1000000 + ) + documents = loader.load() + assert len(documents) > 1 + assert "🦜🔗" in documents[0].page_content + + +def test_sitemap_block_size_to_small() -> None: + """Test sitemap loader.""" + with pytest.raises(ValueError, match="Sitemap blocksize should be at least 1"): + SitemapLoader("https://langchain.readthedocs.io/sitemap.xml", blocksize=0) + + +def test_sitemap_block_num_to_small() -> None: + """Test sitemap loader.""" + with pytest.raises(ValueError, match="Sitemap blocknum can not be lower then 0"): + SitemapLoader( + "https://langchain.readthedocs.io/sitemap.xml", + blocksize=1000000, + blocknum=-1, + ) + + +def test_sitemap_block_does_not_exists() -> None: + """Test sitemap loader.""" + loader = SitemapLoader( + "https://langchain.readthedocs.io/sitemap.xml", blocksize=1000000, blocknum=15 + ) + with pytest.raises( + ValueError, + match="Selected sitemap does not contain enough blocks for given blocknum", + ): + loader.load() + + def test_filter_sitemap() -> None: """Test sitemap loader.""" loader = SitemapLoader( "https://langchain.readthedocs.io/sitemap.xml", - filter_urls=["https://langchain.readthedocs.io/en/stable/"], + filter_urls=["https://python.langchain.com/en/stable/"], ) documents = loader.load() assert len(documents) == 1