From 3e879b47c1f513957e0f30eb169a6670f5f8683c Mon Sep 17 00:00:00 2001
From: Harrison Chase
Date: Tue, 28 Mar 2023 15:28:33 -0700
Subject: [PATCH] Harrison/gitbook (#2044)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Irene López <45119610+ireneisdoomed@users.noreply.github.com>
---
 langchain/document_loaders/gitbook.py  | 18 +++---
 langchain/document_loaders/web_base.py | 23 +++++---
 .../document_loaders/test_gitbook.py   | 56 +++++++++++++++++++
 3 files changed, 81 insertions(+), 16 deletions(-)
 create mode 100644 tests/integration_tests/document_loaders/test_gitbook.py

diff --git a/langchain/document_loaders/gitbook.py b/langchain/document_loaders/gitbook.py
index 1c40b3f6..5b418449 100644
--- a/langchain/document_loaders/gitbook.py
+++ b/langchain/document_loaders/gitbook.py
@@ -1,5 +1,6 @@
 """Loader that loads GitBook."""
 from typing import Any, List, Optional
+from urllib.parse import urlparse
 
 from langchain.docstore.document import Document
 from langchain.document_loaders.web_base import WebBaseLoader
@@ -28,10 +29,15 @@ class GitbookLoader(WebBaseLoader):
             base_url: If `load_all_paths` is True, the relative paths are
                 appended to this base url. Defaults to `web_page` if not set.
         """
-        super().__init__(web_page)
         self.base_url = base_url or web_page
         if self.base_url.endswith("/"):
             self.base_url = self.base_url[:-1]
+        if load_all_paths:
+            # set web_path to the sitemap if we want to crawl all paths
+            web_paths = f"{self.base_url}/sitemap.xml"
+        else:
+            web_paths = web_page
+        super().__init__(web_paths)
         self.load_all_paths = load_all_paths
 
     def load(self) -> List[Document]:
@@ -56,15 +62,9 @@ class GitbookLoader(WebBaseLoader):
         content = page_content_raw.get_text(separator="\n").strip()
         title_if_exists = page_content_raw.find("h1")
         title = title_if_exists.text if title_if_exists else ""
-        metadata = {
-            "source": custom_url if custom_url else self.web_path,
-            "title": title,
-        }
+        metadata = {"source": custom_url or self.web_path, "title": title}
         return Document(page_content=content, metadata=metadata)
 
     def _get_paths(self, soup: Any) -> List[str]:
         """Fetch all relative paths in the navbar."""
-        nav = soup.find("nav")
-        links = nav.findAll("a")
-        # only return relative links
-        return [link.get("href") for link in links if link.get("href")[0] == "/"]
+        return [urlparse(loc.text).path for loc in soup.find_all("loc")]
diff --git a/langchain/document_loaders/web_base.py b/langchain/document_loaders/web_base.py
index cc1e3ab3..f82e6a92 100644
--- a/langchain/document_loaders/web_base.py
+++ b/langchain/document_loaders/web_base.py
@@ -106,19 +106,28 @@ class WebBaseLoader(BaseLoader):
         """Fetch all urls, then return soups for all results."""
         from bs4 import BeautifulSoup
 
-        if parser is None:
-            parser = self.default_parser
-
-        self._check_parser(parser)
-
         results = asyncio.run(self.fetch_all(urls))
-        return [BeautifulSoup(result, parser) for result in results]
+        final_results = []
+        for i, result in enumerate(results):
+            url = urls[i]
+            if parser is None:
+                if url.endswith(".xml"):
+                    parser = "xml"
+                else:
+                    parser = self.default_parser
+                self._check_parser(parser)
+            final_results.append(BeautifulSoup(result, parser))
+
+        return final_results
 
     def _scrape(self, url: str, parser: Union[str, None] = None) -> Any:
         from bs4 import BeautifulSoup
 
         if parser is None:
-            parser = self.default_parser
+            if url.endswith(".xml"):
+                parser = "xml"
+            else:
+                parser = self.default_parser
 
         self._check_parser(parser)
 
diff --git a/tests/integration_tests/document_loaders/test_gitbook.py b/tests/integration_tests/document_loaders/test_gitbook.py
new file mode 100644
index 00000000..d6519a55
--- /dev/null
+++ b/tests/integration_tests/document_loaders/test_gitbook.py
@@ -0,0 +1,56 @@
+from typing import Optional
+
+import pytest
+
+from langchain.document_loaders.gitbook import GitbookLoader
+
+
+class TestGitbookLoader:
+    @pytest.mark.parametrize(
+        "web_page, load_all_paths, base_url, expected_web_path",
+        [
+            ("https://example.com/page1", False, None, "https://example.com/page1"),
+            (
+                "https://example.com/",
+                True,
+                "https://example.com",
+                "https://example.com/sitemap.xml",
+            ),
+        ],
+    )
+    def test_init(
+        self,
+        web_page: str,
+        load_all_paths: bool,
+        base_url: Optional[str],
+        expected_web_path: str,
+    ) -> None:
+        loader = GitbookLoader(
+            web_page, load_all_paths=load_all_paths, base_url=base_url
+        )
+        print(loader.__dict__)
+        assert (
+            loader.base_url == (base_url or web_page)[:-1]
+            if (base_url or web_page).endswith("/")
+            else (base_url or web_page)
+        )
+        assert loader.web_path == expected_web_path
+        assert loader.load_all_paths == load_all_paths
+
+    @pytest.mark.parametrize(
+        "web_page, expected_number_results",
+        [("https://platform-docs.opentargets.org/getting-started", 1)],
+    )
+    def test_load_single_page(
+        self, web_page: str, expected_number_results: int
+    ) -> None:
+        loader = GitbookLoader(web_page)
+        result = loader.load()
+        assert len(result) == expected_number_results
+
+    @pytest.mark.parametrize("web_page", [("https://platform-docs.opentargets.org/")])
+    def test_load_multiple_pages(self, web_page: str) -> None:
+        loader = GitbookLoader(web_page, load_all_paths=True)
+        result = loader.load()
+        print(len(result))
+        assert len(result) > 10
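
Usage note, not part of the diff: a minimal sketch of how GitbookLoader behaves once this patch is applied. The Open Targets docs URL is the one used in the integration test above; any GitBook site that serves /sitemap.xml should behave the same way.

    from langchain.document_loaders.gitbook import GitbookLoader

    # Single page: fetch one URL and return a single Document.
    single = GitbookLoader("https://platform-docs.opentargets.org/getting-started")
    docs = single.load()  # len(docs) == 1

    # Whole site: the constructor now points web_path at <base_url>/sitemap.xml;
    # the sitemap is parsed with the "xml" parser and each <loc> path is loaded.
    site = GitbookLoader("https://platform-docs.opentargets.org/", load_all_paths=True)
    all_docs = site.load()
    print(len(all_docs), all_docs[0].metadata)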