refactor(langchain): improve type annotations in url_playwright and its test

This commit is contained in:
Youngwook Kim 2023-08-09 15:56:24 +09:00
parent 04fcd2d2e0
commit 429de77b3b
2 changed files with 34 additions and 16 deletions

View File

@ -2,11 +2,16 @@
"""
import logging
from abc import ABC, abstractmethod
from typing import List, Optional
from typing import TYPE_CHECKING, List, Optional
from langchain.docstore.document import Document
from langchain.document_loaders.base import BaseLoader
if TYPE_CHECKING:
    # Evaluated by static type checkers only (TYPE_CHECKING is False at
    # runtime), so these imports never execute when the module is loaded.
    #
    # playwright.async_api exposes its classes under the same names as
    # playwright.sync_api (Browser, Page, Response); it has no Async-prefixed
    # exports. Alias the async variants so both families can be referenced
    # side by side in the string annotations below.
    from playwright.async_api import Browser as AsyncBrowser
    from playwright.async_api import Page as AsyncPage
    from playwright.async_api import Response as AsyncResponse
    from playwright.sync_api import Browser, Page, Response
logger = logging.getLogger(__name__)
@ -18,7 +23,7 @@ class PlaywrightEvaluator(ABC):
"""
@abstractmethod
def evaluate(self, page, browser, response):
def evaluate(self, page: "Page", browser: "Browser", response: "Response") -> str:
"""Synchronously process the page and return the resulting text.
Args:
@ -32,7 +37,9 @@ class PlaywrightEvaluator(ABC):
pass
@abstractmethod
async def evaluate_async(self, page, browser, response):
async def evaluate_async(
self, page: "AsyncPage", browser: "AsyncBrowser", response: "AsyncResponse"
) -> str:
"""Asynchronously process the page and return the resulting text.
Args:
@ -50,7 +57,7 @@ class UnstructuredHtmlEvaluator(PlaywrightEvaluator):
"""Evaluates the page HTML content using the `unstructured` library."""
def __init__(self, remove_selectors: Optional[List[str]] = None):
"""Initialize UnstructuredHtmlEvaluator and check if `unstructured` package is installed."""
"""Initialize UnstructuredHtmlEvaluator."""
try:
import unstructured # noqa:F401
except ImportError:
@ -61,8 +68,8 @@ class UnstructuredHtmlEvaluator(PlaywrightEvaluator):
self.remove_selectors = remove_selectors
def evaluate(self, page, browser, response):
"""Synchronously process the HTML content of the page and return a text string."""
def evaluate(self, page: "Page", browser: "Browser", response: "Response") -> str:
"""Synchronously process the HTML content of the page."""
from unstructured.partition.html import partition_html
for selector in self.remove_selectors or []:
@ -75,8 +82,10 @@ class UnstructuredHtmlEvaluator(PlaywrightEvaluator):
elements = partition_html(text=page_source)
return "\n\n".join([str(el) for el in elements])
async def evaluate_async(self, page, browser, response):
"""Asynchronously process the HTML content of the page and return a text string."""
async def evaluate_async(
self, page: "AsyncPage", browser: "AsyncBrowser", response: "AsyncResponse"
) -> str:
"""Asynchronously process the HTML content of the page."""
from unstructured.partition.html import partition_html
for selector in self.remove_selectors or []:
@ -126,7 +135,7 @@ class PlaywrightURLLoader(BaseLoader):
"`remove_selectors` and `evaluator` cannot be both not None"
)
# Use the provided evaluator, if any, otherwise, use the default UnstructuredHtmlEvaluator.
# Use the provided evaluator if given; otherwise use the default.
self.evaluator = evaluator or UnstructuredHtmlEvaluator(remove_selectors)
def load(self) -> List[Document]:

View File

@ -1,16 +1,25 @@
"""Tests for the Playwright URL loader"""
from typing import TYPE_CHECKING
import pytest
from langchain.document_loaders import PlaywrightURLLoader
from langchain.document_loaders.url_playwright import PlaywrightEvaluator
if TYPE_CHECKING:
    # Evaluated by static type checkers only (TYPE_CHECKING is False at
    # runtime), so these imports never execute when the tests are collected.
    #
    # playwright.async_api has no Async-prefixed exports; its classes share
    # the names used by playwright.sync_api, so alias the async variants to
    # keep the annotations on the sync and async methods distinct.
    from playwright.async_api import Browser as AsyncBrowser
    from playwright.async_api import Page as AsyncPage
    from playwright.async_api import Response as AsyncResponse
    from playwright.sync_api import Browser, Page, Response
class TestEvaluator(PageEvaluator):
class TestEvaluator(PlaywrightEvaluator):
"""A simple evaluator for testing purposes."""
def evaluate(self, page, browser, response):
def evaluate(self, page: "Page", browser: "Browser", response: "Response") -> str:
return "test"
async def evaluate_async(self, page, browser, response):
async def evaluate_async(
self, page: "AsyncPage", browser: "AsyncBrowser", response: "AsyncResponse"
) -> str:
return "test"
@ -56,13 +65,13 @@ def test_playwright_url_loader_with_custom_evaluator() -> None:
urls = ["https://www.youtube.com/watch?v=dQw4w9WgXcQ"]
loader = PlaywrightURLLoader(
urls=urls,
page_evaluator=TestEvaluator(),
evaluator=TestEvaluator(),
continue_on_failure=False,
headless=True,
)
docs = loader.load()
assert len(docs) == 1
assert docs[0].page_content == "test-"
assert docs[0].page_content == "test"
@pytest.mark.asyncio
@ -71,10 +80,10 @@ async def test_playwright_async_url_loader_with_custom_evaluator() -> None:
urls = ["https://www.youtube.com/watch?v=dQw4w9WgXcQ"]
loader = PlaywrightURLLoader(
urls=urls,
page_evaluator=TestEvaluator(),
evaluator=TestEvaluator(),
continue_on_failure=False,
headless=True,
)
docs = await loader.aload()
assert len(docs) == 2
assert len(docs) == 1
assert docs[0].page_content == "test"