From e78c9be312e5c59ec96f22d6e531c28329ca6312 Mon Sep 17 00:00:00 2001 From: Adam Quigley Date: Wed, 17 May 2023 08:17:07 +1000 Subject: [PATCH] Add Confluence Loader unit tests (#3333) Adds some basic unit tests for the ConfluenceLoader that can be extended later. Ports this [PR from llama-hub](https://github.com/emptycrown/llama-hub/pull/208) and adapts it to `langchain`. @Jflick58 and @zywilliamli adding you here as potential reviewers --------- Co-authored-by: Dev 2049 --- langchain/document_loaders/confluence.py | 19 +- poetry.lock | 4 +- pyproject.toml | 2 + .../document_loaders/test_confluence.py | 179 ++++++++++++++++++ 4 files changed, 187 insertions(+), 17 deletions(-) create mode 100644 tests/unit_tests/document_loaders/test_confluence.py diff --git a/langchain/document_loaders/confluence.py b/langchain/document_loaders/confluence.py index 40ad9eb653..2f3dde6080 100644 --- a/langchain/document_loaders/confluence.py +++ b/langchain/document_loaders/confluence.py @@ -1,5 +1,6 @@ """Load Data from a Confluence Space""" import logging +from io import BytesIO from typing import Any, Callable, List, Optional, Union from tenacity import ( @@ -370,12 +371,10 @@ class ConfluenceLoader(BaseLoader): def process_attachment(self, page_id: str) -> List[str]: try: - import requests # noqa: F401 from PIL import Image # noqa: F401 except ImportError: raise ImportError( - "`pytesseract` or `pdf2image` or `Pillow` package not found, " - "please run `pip install pytesseract pdf2image Pillow`" + "`Pillow` package not found, " "please run `pip install Pillow`" ) # depending on setup you may also need to set the correct path for @@ -419,9 +418,6 @@ class ConfluenceLoader(BaseLoader): "please run `pip install pytesseract pdf2image`" ) - import pytesseract # noqa: F811 - from pdf2image import convert_from_bytes # noqa: F811 - response = self.confluence.request(path=link, absolute=True) text = "" @@ -444,8 +440,6 @@ class ConfluenceLoader(BaseLoader): def process_image(self, 
link: str) -> str: try: - from io import BytesIO # noqa: F401 - import pytesseract # noqa: F401 from PIL import Image # noqa: F401 except ImportError: @@ -472,8 +466,6 @@ class ConfluenceLoader(BaseLoader): def process_doc(self, link: str) -> str: try: - from io import BytesIO # noqa: F401 - import docx2txt # noqa: F401 except ImportError: raise ImportError( @@ -522,17 +514,14 @@ class ConfluenceLoader(BaseLoader): def process_svg(self, link: str) -> str: try: - from io import BytesIO # noqa: F401 - import pytesseract # noqa: F401 from PIL import Image # noqa: F401 from reportlab.graphics import renderPM # noqa: F401 - from reportlab.graphics.shapes import Drawing # noqa: F401 from svglib.svglib import svg2rlg # noqa: F401 except ImportError: raise ImportError( - "`pytesseract`, `Pillow`, or `svglib` package not found, " - "please run `pip install pytesseract Pillow svglib`" + "`pytesseract`, `Pillow`, `reportlab` or `svglib` package not found, " + "please run `pip install pytesseract Pillow reportlab svglib`" ) response = self.confluence.request(path=link, absolute=True) diff --git a/poetry.lock b/poetry.lock index 9040804c3f..bc21e368b5 100644 --- a/poetry.lock +++ b/poetry.lock @@ -10291,7 +10291,7 @@ all = ["O365", "aleph-alpha-client", "anthropic", "arxiv", "atlassian-python-api azure = ["azure-core", "azure-cosmos", "azure-identity", "openai"] cohere = ["cohere"] embeddings = ["sentence-transformers"] -extended-testing = ["jq", "lxml", "pandas", "pdfminer-six", "pymupdf", "pypdf", "pypdfium2", "telethon", "tqdm"] +extended-testing = ["atlassian-python-api", "beautifulsoup4", "jq", "lxml", "pandas", "pdfminer-six", "pymupdf", "pypdf", "pypdfium2", "telethon", "tqdm"] hnswlib = ["docarray", "hnswlib", "protobuf"] in-memory-store = ["docarray"] llms = ["anthropic", "cohere", "huggingface_hub", "manifest-ml", "nlpcloud", "openai", "torch", "transformers"] @@ -10301,4 +10301,4 @@ qdrant = ["qdrant-client"] [metadata] lock-version = "2.0" python-versions = 
">=3.8.1,<4.0" -content-hash = "b4cc0a605ec9b6ee8752f7d708a5700143815d32f699461ce6470ca44b62701a" +content-hash = "18f77265eb5eb254f4fd308bdb4b53b2e2f7175fa2323af5112b9c62b00f4632" diff --git a/pyproject.toml b/pyproject.toml index 5deb45e864..4f3fe917d5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -192,6 +192,8 @@ extended_testing = [ "pypdfium2", "tqdm", "lxml", + "atlassian-python-api", + "beautifulsoup4", "pandas", "telethon", ] diff --git a/tests/unit_tests/document_loaders/test_confluence.py b/tests/unit_tests/document_loaders/test_confluence.py new file mode 100644 index 0000000000..a3d371bf78 --- /dev/null +++ b/tests/unit_tests/document_loaders/test_confluence.py @@ -0,0 +1,179 @@ +import unittest +from typing import Dict +from unittest.mock import MagicMock, patch + +import pytest + +from langchain.docstore.document import Document +from langchain.document_loaders.confluence import ConfluenceLoader + + +@pytest.fixture +def mock_confluence(): # type: ignore + with patch("atlassian.Confluence") as mock_confluence: + yield mock_confluence + + +@pytest.mark.requires("atlassian", "bs4", "lxml") +class TestConfluenceLoader: + CONFLUENCE_URL = "https://example.atlassian.com/wiki" + MOCK_USERNAME = "user@gmail.com" + MOCK_API_TOKEN = "api_token" + MOCK_SPACE_KEY = "spaceId123" + + def test_confluence_loader_initialization(self, mock_confluence: MagicMock) -> None: + ConfluenceLoader( + url=self.CONFLUENCE_URL, + username=self.MOCK_USERNAME, + api_key=self.MOCK_API_TOKEN, + ) + mock_confluence.assert_called_once_with( + url=self.CONFLUENCE_URL, + username="user@gmail.com", + password="api_token", + cloud=True, + ) + + def test_confluence_loader_initialization_from_env( + self, mock_confluence: MagicMock + ) -> None: + with unittest.mock.patch.dict( + "os.environ", + { + "CONFLUENCE_USERNAME": self.MOCK_USERNAME, + "CONFLUENCE_API_TOKEN": self.MOCK_API_TOKEN, + }, + ): + ConfluenceLoader(url=self.CONFLUENCE_URL) + mock_confluence.assert_called_with( + 
url=self.CONFLUENCE_URL, username=None, password=None, cloud=True + ) + + def test_confluence_loader_load_data_invalid_args(self) -> None: + confluence_loader = ConfluenceLoader( + url=self.CONFLUENCE_URL, + username=self.MOCK_USERNAME, + api_key=self.MOCK_API_TOKEN, + ) + + with pytest.raises( + ValueError, + match="Must specify at least one among `space_key`, `page_ids`, `label`, `cql` parameters.", # noqa: E501 + ): + confluence_loader.load() + + def test_confluence_loader_load_data_by_page_ids( + self, mock_confluence: MagicMock + ) -> None: + mock_confluence.get_page_by_id.side_effect = [ + self._get_mock_page("123"), + self._get_mock_page("456"), + ] + mock_confluence.get_all_restrictions_for_content.side_effect = [ + self._get_mock_page_restrictions("123"), + self._get_mock_page_restrictions("456"), + ] + + confluence_loader = self._get_mock_confluence_loader(mock_confluence) + + mock_page_ids = ["123", "456"] + documents = confluence_loader.load(page_ids=mock_page_ids) + + assert mock_confluence.get_page_by_id.call_count == 2 + assert mock_confluence.get_all_restrictions_for_content.call_count == 2 + + assert len(documents) == 2 + assert all(isinstance(doc, Document) for doc in documents) + assert documents[0].page_content == "Content 123" + assert documents[1].page_content == "Content 456" + + assert mock_confluence.get_all_pages_from_space.call_count == 0 + assert mock_confluence.get_all_pages_by_label.call_count == 0 + assert mock_confluence.cql.call_count == 0 + assert mock_confluence.get_page_child_by_type.call_count == 0 + + def test_confluence_loader_load_data_by_space_id( + self, mock_confluence: MagicMock + ) -> None: + # one response with two pages + mock_confluence.get_all_pages_from_space.return_value = [ + self._get_mock_page("123"), + self._get_mock_page("456"), + ] + mock_confluence.get_all_restrictions_for_content.side_effect = [ + self._get_mock_page_restrictions("123"), + self._get_mock_page_restrictions("456"), + ] + + confluence_loader = 
self._get_mock_confluence_loader(mock_confluence) + + documents = confluence_loader.load(space_key=self.MOCK_SPACE_KEY, max_pages=2) + + assert mock_confluence.get_all_pages_from_space.call_count == 1 + + assert len(documents) == 2 + assert all(isinstance(doc, Document) for doc in documents) + assert documents[0].page_content == "Content 123" + assert documents[1].page_content == "Content 456" + + assert mock_confluence.get_page_by_id.call_count == 0 + assert mock_confluence.get_all_pages_by_label.call_count == 0 + assert mock_confluence.cql.call_count == 0 + assert mock_confluence.get_page_child_by_type.call_count == 0 + + def _get_mock_confluence_loader( + self, mock_confluence: MagicMock + ) -> ConfluenceLoader: + confluence_loader = ConfluenceLoader( + url=self.CONFLUENCE_URL, + username=self.MOCK_USERNAME, + api_key=self.MOCK_API_TOKEN, + ) + confluence_loader.confluence = mock_confluence + return confluence_loader + + def _get_mock_page(self, page_id: str) -> Dict: + return { + "id": f"{page_id}", + "title": f"Page {page_id}", + "body": {"storage": {"value": f"<p>Content {page_id}</p>"}}, + "status": "current", + "type": "page", + "_links": { + "self": f"{self.CONFLUENCE_URL}/rest/api/content/{page_id}", + "tinyui": "/x/tiny_ui_link", + "editui": f"/pages/resumedraft.action?draftId={page_id}", + "webui": f"/spaces/{self.MOCK_SPACE_KEY}/overview", + }, + } + + def _get_mock_page_restrictions(self, page_id: str) -> Dict: + return { + "read": { + "operation": "read", + "restrictions": { + "user": {"results": [], "start": 0, "limit": 200, "size": 0}, + "group": {"results": [], "start": 0, "limit": 200, "size": 0}, + }, + "_expandable": {"content": f"/rest/api/content/{page_id}"}, + "_links": { + "self": f"{self.CONFLUENCE_URL}/rest/api/content/{page_id}/restriction/byOperation/read" # noqa: E501 + }, + }, + "update": { + "operation": "update", + "restrictions": { + "user": {"results": [], "start": 0, "limit": 200, "size": 0}, + "group": {"results": [], "start": 0, "limit": 200, "size": 0}, + }, + "_expandable": {"content": f"/rest/api/content/{page_id}"}, + "_links": { + "self": f"{self.CONFLUENCE_URL}/rest/api/content/{page_id}/restriction/byOperation/update" # noqa: E501 + }, + }, + "_links": { + "self": f"{self.CONFLUENCE_URL}/rest/api/content/{page_id}/restriction/byOperation", # noqa: E501 + "base": self.CONFLUENCE_URL, + "context": "/wiki", + }, + }