diff --git a/docs/modules/indexes/document_loaders.rst b/docs/modules/indexes/document_loaders.rst index c11bc78394..6ac4d9517b 100644 --- a/docs/modules/indexes/document_loaders.rst +++ b/docs/modules/indexes/document_loaders.rst @@ -130,6 +130,7 @@ We need access tokens and sometime other parameters to get access to these datas ./document_loaders/examples/notion.ipynb ./document_loaders/examples/obsidian.ipynb ./document_loaders/examples/psychic.ipynb + ./document_loaders/examples/pyspark_dataframe.ipynb ./document_loaders/examples/readthedocs_documentation.ipynb ./document_loaders/examples/reddit.ipynb ./document_loaders/examples/roam.ipynb diff --git a/docs/modules/indexes/document_loaders/examples/pyspark_dataframe.ipynb b/docs/modules/indexes/document_loaders/examples/pyspark_dataframe.ipynb new file mode 100644 index 0000000000..a09383bbaf --- /dev/null +++ b/docs/modules/indexes/document_loaders/examples/pyspark_dataframe.ipynb @@ -0,0 +1,97 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# PySpack DataFrame Loader\n", + "\n", + "This shows how to load data from a PySpark DataFrame" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#!pip install pyspark" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from pyspark.sql import SparkSession" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "spark = SparkSession.builder.getOrCreate()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "df = spark.read.csv('example_data/mlb_teams_2012.csv', header=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.document_loaders import PySparkDataFrameLoader" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "loader = PySparkDataFrameLoader(spark, df, page_content_column=\"Team\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "loader.load()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.1" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/langchain/document_loaders/__init__.py b/langchain/document_loaders/__init__.py index e96c4efeb1..d89718ac82 100644 --- a/langchain/document_loaders/__init__.py +++ b/langchain/document_loaders/__init__.py @@ -74,6 +74,7 @@ from langchain.document_loaders.pdf import ( ) from langchain.document_loaders.powerpoint import UnstructuredPowerPointLoader from langchain.document_loaders.psychic import PsychicLoader +from langchain.document_loaders.pyspark_dataframe import PySparkDataFrameLoader from langchain.document_loaders.python import PythonLoader from langchain.document_loaders.readthedocs import ReadTheDocsLoader from langchain.document_loaders.reddit import RedditPostsLoader @@ -188,6 +189,7 @@ __all__ = [ "PyPDFDirectoryLoader", "PyPDFLoader", "PyPDFium2Loader", + "PySparkDataFrameLoader", "PythonLoader", "ReadTheDocsLoader", "RedditPostsLoader", diff --git a/langchain/document_loaders/pyspark_dataframe.py b/langchain/document_loaders/pyspark_dataframe.py new file mode 100644 index 0000000000..fe8fe4e79b --- /dev/null +++ b/langchain/document_loaders/pyspark_dataframe.py @@ -0,0 +1,80 @@ +"""Load from a Spark Dataframe object""" +import itertools +import logging +import sys +from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Tuple + +import psutil + +from langchain.docstore.document import Document +from langchain.document_loaders.base import BaseLoader + +logger = logging.getLogger(__file__) + +if TYPE_CHECKING: + from pyspark.sql import SparkSession + + +class PySparkDataFrameLoader(BaseLoader): + """Load PySpark DataFrames""" + + def __init__( + self, + spark_session: Optional["SparkSession"] = None, + df: Optional[Any] = None, + page_content_column: str = "text", + fraction_of_memory: float = 0.1, + ): + """Initialize with a Spark DataFrame object.""" + try: + from pyspark.sql import DataFrame, SparkSession + except ImportError: + raise ValueError( + "pyspark is not installed. " + "Please install it with `pip install pyspark`" + ) + + self.spark = ( + spark_session if spark_session else SparkSession.builder.getOrCreate() + ) + + if not isinstance(df, DataFrame): + raise ValueError( + f"Expected data_frame to be a PySpark DataFrame, got {type(df)}" + ) + self.df = df + self.page_content_column = page_content_column + self.fraction_of_memory = fraction_of_memory + self.num_rows, self.max_num_rows = self.get_num_rows() + self.rdd_df = self.df.rdd.map(list) + self.column_names = self.df.columns + + def get_num_rows(self) -> Tuple[int, int]: + """Gets the amount of "feasible" rows for the DataFrame""" + row = self.df.limit(1).collect()[0] + estimated_row_size = sys.getsizeof(row) + mem_info = psutil.virtual_memory() + available_memory = mem_info.available + max_num_rows = int( + (available_memory / estimated_row_size) * self.fraction_of_memory + ) + return min(max_num_rows, self.df.count()), max_num_rows + + def lazy_load(self) -> Iterator[Document]: + """A lazy loader for document content.""" + for row in self.rdd_df.toLocalIterator(): + metadata = {self.column_names[i]: row[i] for i in range(len(row))} + text = metadata[self.page_content_column] + metadata.pop(self.page_content_column) + yield Document(page_content=text, metadata=metadata) + + def load(self) -> List[Document]: + """Load from the dataframe.""" + if self.df.count() > self.max_num_rows: + logger.warning( + f"The number of DataFrame rows is {self.df.count()}, " + f"but we will only include the amount " + f"of rows that can reasonably fit in memory: {self.num_rows}." + ) + lazy_load_iterator = self.lazy_load() + return list(itertools.islice(lazy_load_iterator, self.num_rows)) diff --git a/poetry.lock b/poetry.lock index 70668ca5a3..633bc6cd12 100644 --- a/poetry.lock +++ b/poetry.lock @@ -6643,6 +6643,18 @@ pytz = "*" requests = "*" requests-oauthlib = ">=0.4.1" +[[package]] +name = "py4j" +version = "0.10.9.7" +description = "Enables Python programs to dynamically access arbitrary Java objects" +category = "main" +optional = true +python-versions = "*" +files = [ + {file = "py4j-0.10.9.7-py2.py3-none-any.whl", hash = "sha256:85defdfd2b2376eb3abf5ca6474b51ab7e0de341c75a02f46dc9b5976f5a5c1b"}, + {file = "py4j-0.10.9.7.tar.gz", hash = "sha256:0b6e5315bb3ada5cf62ac651d107bb2ebc02def3dee9d9548e3baac644ea8dbb"}, +] + [[package]] name = "pyaes" version = "1.6.1" @@ -7229,6 +7241,27 @@ files = [ {file = "PySocks-1.7.1.tar.gz", hash = "sha256:3f8804571ebe159c380ac6de37643bb4685970655d3bba243530d6558b799aa0"}, ] +[[package]] +name = "pyspark" +version = "3.4.0" +description = "Apache Spark Python API" +category = "main" +optional = true +python-versions = ">=3.7" +files = [ + {file = "pyspark-3.4.0.tar.gz", hash = "sha256:167a23e11854adb37f8602de6fcc3a4f96fd5f1e323b9bb83325f38408c5aafd"}, +] + +[package.dependencies] +py4j = "0.10.9.7" + +[package.extras] +connect = ["googleapis-common-protos (>=1.56.4)", "grpcio (>=1.48.1)", "grpcio-status (>=1.48.1)", "numpy (>=1.15)", "pandas (>=1.0.5)", "pyarrow (>=1.0.0)"] +ml = ["numpy (>=1.15)"] +mllib = ["numpy (>=1.15)"] +pandas-on-spark = ["numpy (>=1.15)", "pandas (>=1.0.5)", "pyarrow (>=1.0.0)"] +sql = ["numpy (>=1.15)", "pandas (>=1.0.5)", "pyarrow (>=1.0.0)"] + [[package]] name = "pytesseract" version = "0.3.10" @@ -10920,7 +10953,7 @@ azure = ["azure-ai-formrecognizer", "azure-ai-vision", "azure-cognitiveservices- cohere = ["cohere"] docarray = ["docarray"] embeddings = ["sentence-transformers"] -extended-testing = ["atlassian-python-api", "beautifulsoup4", "beautifulsoup4", "bibtexparser", "chardet", "gql", "html2text", "jq", "lxml", "pandas", "pdfminer-six", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "requests-toolbelt", "scikit-learn", "telethon", "tqdm", "zep-python"] +extended-testing = ["atlassian-python-api", "beautifulsoup4", "beautifulsoup4", "bibtexparser", "chardet", "gql", "html2text", "jq", "lxml", "pandas", "pdfminer-six", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "requests-toolbelt", "scikit-learn", "telethon", "tqdm", "zep-python"] llms = ["anthropic", "cohere", "huggingface_hub", "manifest-ml", "nlpcloud", "openai", "openlm", "torch", "transformers"] openai = ["openai", "tiktoken"] qdrant = ["qdrant-client"] @@ -10929,4 +10962,4 @@ text-helpers = ["chardet"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "1033e47cdab7d3a15fb9322bad64609f77fd3befc47c1a01dc91b22cbbc708a3" +content-hash = "b3dc23f376de141d22b729d038144a1e6d66983a910160c3500fe0d79f8e5917" diff --git a/pyproject.toml b/pyproject.toml index 7b3cb8094c..950a614cb1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -100,6 +100,7 @@ azure-cognitiveservices-speech = {version = "^1.28.0", optional = true} py-trello = {version = "^0.19.0", optional = true} momento = {version = "^1.5.0", optional = true} bibtexparser = {version = "^1.4.0", optional = true} +pyspark = {version = "^3.4.0", optional = true} [tool.poetry.group.docs.dependencies] autodoc_pydantic = "^1.8.0" @@ -301,6 +302,7 @@ extended_testing = [ "html2text", "py-trello", "scikit-learn", + "pyspark", ] [tool.ruff] diff --git a/tests/integration_tests/document_loaders/test_pyspark_dataframe_loader.py b/tests/integration_tests/document_loaders/test_pyspark_dataframe_loader.py new file mode 100644 index 0000000000..74bdb3e291 --- /dev/null +++ b/tests/integration_tests/document_loaders/test_pyspark_dataframe_loader.py @@ -0,0 +1,38 @@ +import random +import string + +from langchain.docstore.document import Document +from langchain.document_loaders.pyspark_dataframe import PySparkDataFrameLoader + + +def test_pyspark_loader_load_valid_data() -> None: + from pyspark.sql import SparkSession + + # Requires a session to be set up + spark = SparkSession.builder.getOrCreate() + data = [ + (random.choice(string.ascii_letters), random.randint(0, 1)) for _ in range(3) + ] + df = spark.createDataFrame(data, ["text", "label"]) + + expected_docs = [ + Document( + page_content=data[0][0], + metadata={"label": data[0][1]}, + ), + Document( + page_content=data[1][0], + metadata={"label": data[1][1]}, + ), + Document( + page_content=data[2][0], + metadata={"label": data[2][1]}, + ), + ] + + loader = PySparkDataFrameLoader( + spark_session=spark, df=df, page_content_column="text" + ) + result = loader.load() + + assert result == expected_docs