Allow to specify a custom loader for GcsFileLoader (#8868)

Co-authored-by: Leonid Kuligin <kuligin@google.com>
pull/8903/head
Leonid Kuligin 1 year ago committed by GitHub
parent ff44fe4e16
commit b52a3785c9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -73,13 +73,27 @@
"loader.load()" "loader.load()"
] ]
}, },
{
"cell_type": "markdown",
"id": "41c8a46f",
"metadata": {},
"source": [
"If you want to use an alternative loader, you can provide a custom function, for example:"
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"id": "eba3002d", "id": "eba3002d",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [] "source": [
"from langchain.document_loaders import PyPDFLoader\n",
"def load_pdf(file_path):\n",
" return PyPDFLoader(file_path)\n",
"\n",
"loader = GCSFileLoader(project_name=\"aist\", bucket=\"testing-hwc\", blob=\"fake.pdf\", loader_func=load_pdf)"
]
} }
], ],
"metadata": { "metadata": {

@ -1,5 +1,5 @@
"""Loading logic for loading documents from an GCS directory.""" """Loading logic for loading documents from an GCS directory."""
from typing import List from typing import Callable, List, Optional
from langchain.docstore.document import Document from langchain.docstore.document import Document
from langchain.document_loaders.base import BaseLoader from langchain.document_loaders.base import BaseLoader
@ -9,17 +9,27 @@ from langchain.document_loaders.gcs_file import GCSFileLoader
class GCSDirectoryLoader(BaseLoader): class GCSDirectoryLoader(BaseLoader):
"""Loads Documents from GCS.""" """Loads Documents from GCS."""
def __init__(self, project_name: str, bucket: str, prefix: str = ""): def __init__(
self,
project_name: str,
bucket: str,
prefix: str = "",
loader_func: Optional[Callable[[str], BaseLoader]] = None,
):
"""Initialize with bucket and key name. """Initialize with bucket and key name.
Args: Args:
project_name: The name of the project for the GCS bucket. project_name: The name of the project for the GCS bucket.
bucket: The name of the GCS bucket. bucket: The name of the GCS bucket.
prefix: The prefix of the GCS bucket. prefix: The prefix of the GCS bucket.
loader_func: A loader function that instatiates a loader based on a
file_path argument. If nothing is provided, the GCSFileLoader
would use its default loader.
""" """
self.project_name = project_name self.project_name = project_name
self.bucket = bucket self.bucket = bucket
self.prefix = prefix self.prefix = prefix
self._loader_func = loader_func
def load(self) -> List[Document]: def load(self) -> List[Document]:
"""Load documents.""" """Load documents."""
@ -37,6 +47,8 @@ class GCSDirectoryLoader(BaseLoader):
# intermediate directories on the fly # intermediate directories on the fly
if blob.name.endswith("/"): if blob.name.endswith("/"):
continue continue
loader = GCSFileLoader(self.project_name, self.bucket, blob.name) loader = GCSFileLoader(
self.project_name, self.bucket, blob.name, loader_func=self._loader_func
)
docs.extend(loader.load()) docs.extend(loader.load())
return docs return docs

@ -1,7 +1,7 @@
"""Load documents from a GCS file.""" """Load documents from a GCS file."""
import os import os
import tempfile import tempfile
from typing import List from typing import Callable, List, Optional
from langchain.docstore.document import Document from langchain.docstore.document import Document
from langchain.document_loaders.base import BaseLoader from langchain.document_loaders.base import BaseLoader
@ -11,18 +11,42 @@ from langchain.document_loaders.unstructured import UnstructuredFileLoader
class GCSFileLoader(BaseLoader): class GCSFileLoader(BaseLoader):
"""Load Documents from a GCS file.""" """Load Documents from a GCS file."""
def __init__(self, project_name: str, bucket: str, blob: str): def __init__(
self,
project_name: str,
bucket: str,
blob: str,
loader_func: Optional[Callable[[str], BaseLoader]] = None,
):
"""Initialize with bucket and key name. """Initialize with bucket and key name.
Args: Args:
project_name: The name of the project to load project_name: The name of the project to load
bucket: The name of the GCS bucket. bucket: The name of the GCS bucket.
blob: The name of the GCS blob to load. blob: The name of the GCS blob to load.
loader_func: A loader function that instatiates a loader based on a
file_path argument. If nothing is provided, the
UnstructuredFileLoader is used.
Examples:
To use an alternative PDF loader:
>> from from langchain.document_loaders import PyPDFLoader
>> loader = GCSFileLoader(..., loader_func=PyPDFLoader)
To use UnstructuredFileLoader with additional arguments:
>> loader = GCSFileLoader(...,
>> loader_func=lambda x: UnstructuredFileLoader(x, mode="elements"))
""" """
self.bucket = bucket self.bucket = bucket
self.blob = blob self.blob = blob
self.project_name = project_name self.project_name = project_name
def default_loader_func(file_path: str) -> BaseLoader:
return UnstructuredFileLoader(file_path)
self._loader_func = loader_func if loader_func else default_loader_func
def load(self) -> List[Document]: def load(self) -> List[Document]:
"""Load documents.""" """Load documents."""
try: try:
@ -44,5 +68,5 @@ class GCSFileLoader(BaseLoader):
os.makedirs(os.path.dirname(file_path), exist_ok=True) os.makedirs(os.path.dirname(file_path), exist_ok=True)
# Download the file to a destination # Download the file to a destination
blob.download_to_filename(file_path) blob.download_to_filename(file_path)
loader = UnstructuredFileLoader(file_path) loader = self._loader_func(file_path)
return loader.load() return loader.load()

Loading…
Cancel
Save