Harrison/rec gd (#3054)

Co-authored-by: Benjamin Scholtz <BenSchZA@users.noreply.github.com>
This commit is contained in:
Harrison Chase 2023-04-17 21:02:35 -07:00 committed by GitHub
parent eee2f23a79
commit 5107fac656
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 34 additions and 15 deletions

View File

@ -44,7 +44,11 @@
},
"outputs": [],
"source": [
"loader = GoogleDriveLoader(folder_id=\"1yucgL9WGgWZdM1TOuKkeghlPizuzMYb5\")"
"loader = GoogleDriveLoader(\n",
" folder_id=\"1yucgL9WGgWZdM1TOuKkeghlPizuzMYb5\",\n",
" # Optional: configure whether to recursively fetch files from subfolders. Defaults to False.\n",
" recursive=False\n",
")"
]
},
{

View File

@ -10,7 +10,7 @@
# https://cloud.google.com/iam/docs/service-accounts-create
from pathlib import Path
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union
from pydantic import BaseModel, root_validator, validator
@ -29,6 +29,7 @@ class GoogleDriveLoader(BaseLoader, BaseModel):
folder_id: Optional[str] = None
document_ids: Optional[List[str]] = None
file_ids: Optional[List[str]] = None
recursive: bool = False
@root_validator
def validate_folder_id_or_document_ids(
@ -170,35 +171,49 @@ class GoogleDriveLoader(BaseLoader, BaseModel):
}
return Document(page_content=text, metadata=metadata)
def _load_documents_from_folder(self) -> List[Document]:
def _load_documents_from_folder(self, folder_id: str) -> List[Document]:
    """Load every supported file listed for a Google Drive folder.

    The file listing comes from ``_fetch_files_recursive``. Google Docs,
    Google Sheets and PDF files are converted into documents; entries with
    any other MIME type are skipped.

    Args:
        folder_id: ID of the Drive folder whose files should be loaded.

    Returns:
        The documents produced from all supported files, in listing order.
    """
    # Imported lazily so the Google client is only needed when this runs.
    from googleapiclient.discovery import build

    service = build("drive", "v3", credentials=self._load_credentials())

    documents = []
    for entry in self._fetch_files_recursive(service, folder_id):
        mime_type = entry["mimeType"]
        if mime_type == "application/vnd.google-apps.document":
            # A Google Doc yields a single document.
            documents.append(self._load_document_from_id(entry["id"]))  # type: ignore
        elif mime_type == "application/vnd.google-apps.spreadsheet":
            # A sheet yields several documents, hence extend.
            documents.extend(self._load_sheet_from_id(entry["id"]))  # type: ignore
        elif mime_type == "application/pdf":
            documents.extend(self._load_file_from_id(entry["id"]))  # type: ignore
        # Any other MIME type is intentionally ignored.
    return documents
def _fetch_files_recursive(
    self, service: Any, folder_id: str
) -> List[Dict[str, Union[str, List[str]]]]:
    """Fetch metadata for the files contained in a Drive folder.

    Every result page of the listing is consumed, so folders holding more
    entries than a single page can carry are returned in full. Subfolders
    are descended into only when ``self.recursive`` is true; folder entries
    themselves are never part of the result.

    Args:
        service: Drive v3 service object used to issue ``files().list``.
        folder_id: ID of the folder whose children are listed.

    Returns:
        One metadata dict (``id``, ``name``, ``mimeType``, ``parents``) per
        non-folder file found.
    """
    returns: List[Dict[str, Union[str, List[str]]]] = []
    page_token: Optional[str] = None

    while True:
        list_kwargs: Dict[str, Any] = {
            "q": f"'{folder_id}' in parents",
            "pageSize": 1000,
            "includeItemsFromAllDrives": True,
            "supportsAllDrives": True,
            "fields": "nextPageToken, files(id, name, mimeType, parents)",
        }
        # Only send a page token once the API has handed one back.
        if page_token:
            list_kwargs["pageToken"] = page_token

        results = service.files().list(**list_kwargs).execute()

        for file in results.get("files", []):
            if file["mimeType"] == "application/vnd.google-apps.folder":
                # Folders are containers only: expand them when asked to,
                # otherwise drop them from the result.
                if self.recursive:
                    returns.extend(self._fetch_files_recursive(service, file["id"]))
            else:
                returns.append(file)

        # A missing token means the listing is exhausted.
        page_token = results.get("nextPageToken")
        if not page_token:
            break

    return returns
@ -256,7 +271,7 @@ class GoogleDriveLoader(BaseLoader, BaseModel):
def load(self) -> List[Document]:
"""Load documents."""
if self.folder_id:
return self._load_documents_from_folder()
return self._load_documents_from_folder(self.folder_id)
elif self.document_ids:
return self._load_documents_from_ids()
else: