Harrison/spark reader (#5405)

Co-authored-by: Rithwik Ediga Lakhamsani <rithwik.ediga@databricks.com>
Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>
Harrison Chase committed 12 months ago via GitHub
parent 8259f9b7fa
commit 760632b292

@@ -130,6 +130,7 @@ We need access tokens and sometimes other parameters to get access to these datas
./document_loaders/examples/notion.ipynb
./document_loaders/examples/obsidian.ipynb
./document_loaders/examples/psychic.ipynb
./document_loaders/examples/pyspark_dataframe.ipynb
./document_loaders/examples/readthedocs_documentation.ipynb
./document_loaders/examples/reddit.ipynb
./document_loaders/examples/roam.ipynb

@@ -0,0 +1,97 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# PySpack DataFrame Loader\n",
"\n",
"This shows how to load data from a PySpark DataFrame"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#!pip install pyspark"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from pyspark.sql import SparkSession"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"spark = SparkSession.builder.getOrCreate()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"df = spark.read.csv('example_data/mlb_teams_2012.csv', header=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain.document_loaders import PySparkDataFrameLoader"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"loader = PySparkDataFrameLoader(spark, df, page_content_column=\"Team\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"loader.load()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
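Note that the notebook only demonstrates load(); the loader also exposes lazy_load() (see the implementation below), which streams one Document per row instead of materializing the full list. A minimal sketch of an extra cell, not part of the committed notebook:

for doc in loader.lazy_load():
    print(doc)  # one Document per DataFrame row, streamed lazily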

@@ -74,6 +74,7 @@ from langchain.document_loaders.pdf import (
)
from langchain.document_loaders.powerpoint import UnstructuredPowerPointLoader
from langchain.document_loaders.psychic import PsychicLoader
from langchain.document_loaders.pyspark_dataframe import PySparkDataFrameLoader
from langchain.document_loaders.python import PythonLoader
from langchain.document_loaders.readthedocs import ReadTheDocsLoader
from langchain.document_loaders.reddit import RedditPostsLoader
@@ -188,6 +189,7 @@ __all__ = [
"PyPDFDirectoryLoader",
"PyPDFLoader",
"PyPDFium2Loader",
"PySparkDataFrameLoader",
"PythonLoader",
"ReadTheDocsLoader",
"RedditPostsLoader",

@@ -0,0 +1,80 @@
"""Load from a Spark Dataframe object"""
import itertools
import logging
import sys
from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Tuple
import psutil
from langchain.docstore.document import Document
from langchain.document_loaders.base import BaseLoader
logger = logging.getLogger(__file__)
if TYPE_CHECKING:
from pyspark.sql import SparkSession
class PySparkDataFrameLoader(BaseLoader):
"""Load PySpark DataFrames"""
def __init__(
self,
spark_session: Optional["SparkSession"] = None,
df: Optional[Any] = None,
page_content_column: str = "text",
fraction_of_memory: float = 0.1,
):
"""Initialize with a Spark DataFrame object."""
try:
from pyspark.sql import DataFrame, SparkSession
except ImportError:
raise ValueError(
"pyspark is not installed. "
"Please install it with `pip install pyspark`"
)
self.spark = (
spark_session if spark_session else SparkSession.builder.getOrCreate()
)
if not isinstance(df, DataFrame):
raise ValueError(
f"Expected data_frame to be a PySpark DataFrame, got {type(df)}"
)
self.df = df
self.page_content_column = page_content_column
self.fraction_of_memory = fraction_of_memory
self.num_rows, self.max_num_rows = self.get_num_rows()
self.rdd_df = self.df.rdd.map(list)
self.column_names = self.df.columns
def get_num_rows(self) -> Tuple[int, int]:
"""Gets the amount of "feasible" rows for the DataFrame"""
row = self.df.limit(1).collect()[0]
estimated_row_size = sys.getsizeof(row)
mem_info = psutil.virtual_memory()
available_memory = mem_info.available
max_num_rows = int(
(available_memory / estimated_row_size) * self.fraction_of_memory
)
return min(max_num_rows, self.df.count()), max_num_rows
def lazy_load(self) -> Iterator[Document]:
"""A lazy loader for document content."""
for row in self.rdd_df.toLocalIterator():
metadata = {self.column_names[i]: row[i] for i in range(len(row))}
text = metadata[self.page_content_column]
metadata.pop(self.page_content_column)
yield Document(page_content=text, metadata=metadata)
def load(self) -> List[Document]:
"""Load from the dataframe."""
if self.df.count() > self.max_num_rows:
logger.warning(
f"The number of DataFrame rows is {self.df.count()}, "
f"but we will only include the amount "
f"of rows that can reasonably fit in memory: {self.num_rows}."
)
lazy_load_iterator = self.lazy_load()
return list(itertools.islice(lazy_load_iterator, self.num_rows))
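For reference, a minimal end-to-end sketch of the new loader (the Spark session and the two-column DataFrame here are illustrative, not taken from the PR):

from pyspark.sql import SparkSession
from langchain.document_loaders import PySparkDataFrameLoader

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([("hello", 0), ("world", 1)], ["text", "label"])

loader = PySparkDataFrameLoader(spark_session=spark, df=df, page_content_column="text")
docs = loader.load()
# Each row becomes a Document: page_content is taken from the "text" column,
# and every remaining column ("label") is stored in metadata.

To put numbers on the get_num_rows() heuristic: with ~8 GiB of available memory and an estimated row size of ~1 KiB, max_num_rows ≈ 0.1 × (8 × 2^30 / 2^10) ≈ 840,000, so load() on a larger DataFrame truncates to that many rows and logs the warning above.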

poetry.lock (generated)

@@ -6643,6 +6643,18 @@ pytz = "*"
requests = "*"
requests-oauthlib = ">=0.4.1"
[[package]]
name = "py4j"
version = "0.10.9.7"
description = "Enables Python programs to dynamically access arbitrary Java objects"
category = "main"
optional = true
python-versions = "*"
files = [
{file = "py4j-0.10.9.7-py2.py3-none-any.whl", hash = "sha256:85defdfd2b2376eb3abf5ca6474b51ab7e0de341c75a02f46dc9b5976f5a5c1b"},
{file = "py4j-0.10.9.7.tar.gz", hash = "sha256:0b6e5315bb3ada5cf62ac651d107bb2ebc02def3dee9d9548e3baac644ea8dbb"},
]
[[package]]
name = "pyaes"
version = "1.6.1"
@@ -7229,6 +7241,27 @@ files = [
{file = "PySocks-1.7.1.tar.gz", hash = "sha256:3f8804571ebe159c380ac6de37643bb4685970655d3bba243530d6558b799aa0"},
]
[[package]]
name = "pyspark"
version = "3.4.0"
description = "Apache Spark Python API"
category = "main"
optional = true
python-versions = ">=3.7"
files = [
{file = "pyspark-3.4.0.tar.gz", hash = "sha256:167a23e11854adb37f8602de6fcc3a4f96fd5f1e323b9bb83325f38408c5aafd"},
]
[package.dependencies]
py4j = "0.10.9.7"
[package.extras]
connect = ["googleapis-common-protos (>=1.56.4)", "grpcio (>=1.48.1)", "grpcio-status (>=1.48.1)", "numpy (>=1.15)", "pandas (>=1.0.5)", "pyarrow (>=1.0.0)"]
ml = ["numpy (>=1.15)"]
mllib = ["numpy (>=1.15)"]
pandas-on-spark = ["numpy (>=1.15)", "pandas (>=1.0.5)", "pyarrow (>=1.0.0)"]
sql = ["numpy (>=1.15)", "pandas (>=1.0.5)", "pyarrow (>=1.0.0)"]
[[package]]
name = "pytesseract"
version = "0.3.10"
@@ -10920,7 +10953,7 @@ azure = ["azure-ai-formrecognizer", "azure-ai-vision", "azure-cognitiveservices-
cohere = ["cohere"]
docarray = ["docarray"]
embeddings = ["sentence-transformers"]
extended-testing = ["atlassian-python-api", "beautifulsoup4", "beautifulsoup4", "bibtexparser", "chardet", "gql", "html2text", "jq", "lxml", "pandas", "pdfminer-six", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "requests-toolbelt", "scikit-learn", "telethon", "tqdm", "zep-python"]
extended-testing = ["atlassian-python-api", "beautifulsoup4", "beautifulsoup4", "bibtexparser", "chardet", "gql", "html2text", "jq", "lxml", "pandas", "pdfminer-six", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "requests-toolbelt", "scikit-learn", "telethon", "tqdm", "zep-python"]
llms = ["anthropic", "cohere", "huggingface_hub", "manifest-ml", "nlpcloud", "openai", "openlm", "torch", "transformers"]
openai = ["openai", "tiktoken"]
qdrant = ["qdrant-client"]
@@ -10929,4 +10962,4 @@ text-helpers = ["chardet"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<4.0"
content-hash = "1033e47cdab7d3a15fb9322bad64609f77fd3befc47c1a01dc91b22cbbc708a3"
content-hash = "b3dc23f376de141d22b729d038144a1e6d66983a910160c3500fe0d79f8e5917"

@@ -100,6 +100,7 @@ azure-cognitiveservices-speech = {version = "^1.28.0", optional = true}
py-trello = {version = "^0.19.0", optional = true}
momento = {version = "^1.5.0", optional = true}
bibtexparser = {version = "^1.4.0", optional = true}
pyspark = {version = "^3.4.0", optional = true}
[tool.poetry.group.docs.dependencies]
autodoc_pydantic = "^1.8.0"
@@ -301,6 +302,7 @@ extended_testing = [
"html2text",
"py-trello",
"scikit-learn",
"pyspark",
]
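Since pyspark is wired into the extended_testing extra here, it can be installed alongside the other optional test dependencies, e.g. pip install "langchain[extended_testing]" (or simply pip install pyspark on top of an existing install); the exact command is a suggestion, not part of the diff.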
[tool.ruff]

@@ -0,0 +1,38 @@
import random
import string

from langchain.docstore.document import Document
from langchain.document_loaders.pyspark_dataframe import PySparkDataFrameLoader


def test_pyspark_loader_load_valid_data() -> None:
    from pyspark.sql import SparkSession

    # Requires a session to be set up
    spark = SparkSession.builder.getOrCreate()
    data = [
        (random.choice(string.ascii_letters), random.randint(0, 1)) for _ in range(3)
    ]
    df = spark.createDataFrame(data, ["text", "label"])

    expected_docs = [
        Document(
            page_content=data[0][0],
            metadata={"label": data[0][1]},
        ),
        Document(
            page_content=data[1][0],
            metadata={"label": data[1][1]},
        ),
        Document(
            page_content=data[2][0],
            metadata={"label": data[2][1]},
        ),
    ]
    loader = PySparkDataFrameLoader(
        spark_session=spark, df=df, page_content_column="text"
    )
    result = loader.load()

    assert result == expected_docs
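Note that this integration test assumes a working local Spark setup (pyspark pulls in py4j but still needs a JVM); something like pytest -k test_pyspark_loader_load_valid_data should select it, though the exact invocation depends on the repo's test layout.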