Harrison/gitbook (#2044)

Co-authored-by: Irene López <45119610+ireneisdoomed@users.noreply.github.com>
Harrison Chase 2023-03-28 15:28:33 -07:00 committed by GitHub
parent 859502b16c
commit 3e879b47c1
3 changed files with 81 additions and 16 deletions

langchain/document_loaders/gitbook.py

@@ -1,5 +1,6 @@
 """Loader that loads GitBook."""
 from typing import Any, List, Optional
+from urllib.parse import urlparse

 from langchain.docstore.document import Document
 from langchain.document_loaders.web_base import WebBaseLoader
@@ -28,10 +29,15 @@ class GitbookLoader(WebBaseLoader):
             base_url: If `load_all_paths` is True, the relative paths are
                 appended to this base url. Defaults to `web_page` if not set.
         """
-        super().__init__(web_page)
         self.base_url = base_url or web_page
         if self.base_url.endswith("/"):
             self.base_url = self.base_url[:-1]
+        if load_all_paths:
+            # set web_path to the sitemap if we want to crawl all paths
+            web_paths = f"{self.base_url}/sitemap.xml"
+        else:
+            web_paths = web_page
+        super().__init__(web_paths)
         self.load_all_paths = load_all_paths

     def load(self) -> List[Document]:
@@ -56,15 +62,9 @@ class GitbookLoader(WebBaseLoader):
         content = page_content_raw.get_text(separator="\n").strip()
         title_if_exists = page_content_raw.find("h1")
         title = title_if_exists.text if title_if_exists else ""
-        metadata = {
-            "source": custom_url if custom_url else self.web_path,
-            "title": title,
-        }
+        metadata = {"source": custom_url or self.web_path, "title": title}
         return Document(page_content=content, metadata=metadata)

     def _get_paths(self, soup: Any) -> List[str]:
         """Fetch all relative paths in the navbar."""
-        nav = soup.find("nav")
-        links = nav.findAll("a")
-        # only return relative links
-        return [link.get("href") for link in links if link.get("href")[0] == "/"]
+        return [urlparse(loc.text).path for loc in soup.find_all("loc")]
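For context, a minimal usage sketch of the loader after this change (the URL is illustrative, and a live network connection is required):

from langchain.document_loaders.gitbook import GitbookLoader

# Load a single GitBook page.
single_loader = GitbookLoader("https://docs.gitbook.com")
docs = single_loader.load()

# Crawl every page listed in the site's sitemap.xml.
site_loader = GitbookLoader("https://docs.gitbook.com", load_all_paths=True)
all_docs = site_loader.load()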

langchain/document_loaders/web_base.py

@@ -106,18 +106,27 @@ class WebBaseLoader(BaseLoader):
         """Fetch all urls, then return soups for all results."""
         from bs4 import BeautifulSoup

-        if parser is None:
-            parser = self.default_parser
-        self._check_parser(parser)
-
         results = asyncio.run(self.fetch_all(urls))
-        return [BeautifulSoup(result, parser) for result in results]
+        final_results = []
+        for i, result in enumerate(results):
+            url = urls[i]
+            if parser is None:
+                if url.endswith(".xml"):
+                    parser = "xml"
+                else:
+                    parser = self.default_parser
+            self._check_parser(parser)
+            final_results.append(BeautifulSoup(result, parser))
+        return final_results

     def _scrape(self, url: str, parser: Union[str, None] = None) -> Any:
         from bs4 import BeautifulSoup

         if parser is None:
-            parser = self.default_parser
+            if url.endswith(".xml"):
+                parser = "xml"
+            else:
+                parser = self.default_parser
         self._check_parser(parser)
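Both methods now share one rule: URLs ending in .xml, such as the sitemap the GitbookLoader now requests, are parsed with BeautifulSoup's "xml" parser instead of the HTML default. A standalone sketch of that rule, using a hypothetical function name that is not part of the library:

def pick_parser(url: str, default_parser: str = "html.parser") -> str:
    # Sitemaps are XML documents, so they need bs4's "xml" parser
    # (backed by lxml) rather than an HTML parser.
    return "xml" if url.endswith(".xml") else default_parser

assert pick_parser("https://example.com/sitemap.xml") == "xml"
assert pick_parser("https://example.com/page1") == "html.parser"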

tests/integration_tests/document_loaders/test_gitbook.py

@@ -0,0 +1,56 @@
+from typing import Optional
+
+import pytest
+
+from langchain.document_loaders.gitbook import GitbookLoader
+
+
+class TestGitbookLoader:
+    @pytest.mark.parametrize(
+        "web_page, load_all_paths, base_url, expected_web_path",
+        [
+            ("https://example.com/page1", False, None, "https://example.com/page1"),
+            (
+                "https://example.com/",
+                True,
+                "https://example.com",
+                "https://example.com/sitemap.xml",
+            ),
+        ],
+    )
+    def test_init(
+        self,
+        web_page: str,
+        load_all_paths: bool,
+        base_url: Optional[str],
+        expected_web_path: str,
+    ) -> None:
+        loader = GitbookLoader(
+            web_page, load_all_paths=load_all_paths, base_url=base_url
+        )
+        print(loader.__dict__)
+        assert (
+            loader.base_url == (base_url or web_page)[:-1]
+            if (base_url or web_page).endswith("/")
+            else (base_url or web_page)
+        )
+        assert loader.web_path == expected_web_path
+        assert loader.load_all_paths == load_all_paths
+
+    @pytest.mark.parametrize(
+        "web_page, expected_number_results",
+        [("https://platform-docs.opentargets.org/getting-started", 1)],
+    )
+    def test_load_single_page(
+        self, web_page: str, expected_number_results: int
+    ) -> None:
+        loader = GitbookLoader(web_page)
+        result = loader.load()
+        assert len(result) == expected_number_results
+
+    @pytest.mark.parametrize("web_page", [("https://platform-docs.opentargets.org/")])
+    def test_load_multiple_pages(self, web_page: str) -> None:
+        loader = GitbookLoader(web_page, load_all_paths=True)
+        result = loader.load()
+        print(len(result))
+        assert len(result) > 10
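These are integration tests against live GitBook sites, so they require network access. One way to run just this file from the repo root, assuming the test path shown above:

import pytest

# Assumed path; adjust if the integration tests live elsewhere.
pytest.main(["tests/integration_tests/document_loaders/test_gitbook.py", "-v"])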