forked from Archives/langchain
Harrison/spark reader (#5405)
Co-authored-by: Rithwik Ediga Lakhamsani <rithwik.ediga@databricks.com> Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>searx_updates
parent
8259f9b7fa
commit
760632b292
@ -0,0 +1,97 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# PySpack DataFrame Loader\n",
|
||||||
|
"\n",
|
||||||
|
"This shows how to load data from a PySpark DataFrame"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"#!pip install pyspark"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from pyspark.sql import SparkSession"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"spark = SparkSession.builder.getOrCreate()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"df = spark.read.csv('example_data/mlb_teams_2012.csv', header=True)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from langchain.document_loaders import PySparkDataFrameLoader"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"loader = PySparkDataFrameLoader(spark, df, page_content_column=\"Team\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"loader.load()"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3 (ipykernel)",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.9.1"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 2
|
||||||
|
}
|
@ -0,0 +1,80 @@
|
|||||||
|
"""Load from a Spark Dataframe object"""
|
||||||
|
import itertools
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Tuple
|
||||||
|
|
||||||
|
import psutil
|
||||||
|
|
||||||
|
from langchain.docstore.document import Document
|
||||||
|
from langchain.document_loaders.base import BaseLoader
|
||||||
|
|
||||||
|
logger = logging.getLogger(__file__)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from pyspark.sql import SparkSession
|
||||||
|
|
||||||
|
|
||||||
|
class PySparkDataFrameLoader(BaseLoader):
|
||||||
|
"""Load PySpark DataFrames"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
spark_session: Optional["SparkSession"] = None,
|
||||||
|
df: Optional[Any] = None,
|
||||||
|
page_content_column: str = "text",
|
||||||
|
fraction_of_memory: float = 0.1,
|
||||||
|
):
|
||||||
|
"""Initialize with a Spark DataFrame object."""
|
||||||
|
try:
|
||||||
|
from pyspark.sql import DataFrame, SparkSession
|
||||||
|
except ImportError:
|
||||||
|
raise ValueError(
|
||||||
|
"pyspark is not installed. "
|
||||||
|
"Please install it with `pip install pyspark`"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.spark = (
|
||||||
|
spark_session if spark_session else SparkSession.builder.getOrCreate()
|
||||||
|
)
|
||||||
|
|
||||||
|
if not isinstance(df, DataFrame):
|
||||||
|
raise ValueError(
|
||||||
|
f"Expected data_frame to be a PySpark DataFrame, got {type(df)}"
|
||||||
|
)
|
||||||
|
self.df = df
|
||||||
|
self.page_content_column = page_content_column
|
||||||
|
self.fraction_of_memory = fraction_of_memory
|
||||||
|
self.num_rows, self.max_num_rows = self.get_num_rows()
|
||||||
|
self.rdd_df = self.df.rdd.map(list)
|
||||||
|
self.column_names = self.df.columns
|
||||||
|
|
||||||
|
def get_num_rows(self) -> Tuple[int, int]:
|
||||||
|
"""Gets the amount of "feasible" rows for the DataFrame"""
|
||||||
|
row = self.df.limit(1).collect()[0]
|
||||||
|
estimated_row_size = sys.getsizeof(row)
|
||||||
|
mem_info = psutil.virtual_memory()
|
||||||
|
available_memory = mem_info.available
|
||||||
|
max_num_rows = int(
|
||||||
|
(available_memory / estimated_row_size) * self.fraction_of_memory
|
||||||
|
)
|
||||||
|
return min(max_num_rows, self.df.count()), max_num_rows
|
||||||
|
|
||||||
|
def lazy_load(self) -> Iterator[Document]:
|
||||||
|
"""A lazy loader for document content."""
|
||||||
|
for row in self.rdd_df.toLocalIterator():
|
||||||
|
metadata = {self.column_names[i]: row[i] for i in range(len(row))}
|
||||||
|
text = metadata[self.page_content_column]
|
||||||
|
metadata.pop(self.page_content_column)
|
||||||
|
yield Document(page_content=text, metadata=metadata)
|
||||||
|
|
||||||
|
def load(self) -> List[Document]:
|
||||||
|
"""Load from the dataframe."""
|
||||||
|
if self.df.count() > self.max_num_rows:
|
||||||
|
logger.warning(
|
||||||
|
f"The number of DataFrame rows is {self.df.count()}, "
|
||||||
|
f"but we will only include the amount "
|
||||||
|
f"of rows that can reasonably fit in memory: {self.num_rows}."
|
||||||
|
)
|
||||||
|
lazy_load_iterator = self.lazy_load()
|
||||||
|
return list(itertools.islice(lazy_load_iterator, self.num_rows))
|
@ -0,0 +1,38 @@
|
|||||||
|
import random
|
||||||
|
import string
|
||||||
|
|
||||||
|
from langchain.docstore.document import Document
|
||||||
|
from langchain.document_loaders.pyspark_dataframe import PySparkDataFrameLoader
|
||||||
|
|
||||||
|
|
||||||
|
def test_pyspark_loader_load_valid_data() -> None:
|
||||||
|
from pyspark.sql import SparkSession
|
||||||
|
|
||||||
|
# Requires a session to be set up
|
||||||
|
spark = SparkSession.builder.getOrCreate()
|
||||||
|
data = [
|
||||||
|
(random.choice(string.ascii_letters), random.randint(0, 1)) for _ in range(3)
|
||||||
|
]
|
||||||
|
df = spark.createDataFrame(data, ["text", "label"])
|
||||||
|
|
||||||
|
expected_docs = [
|
||||||
|
Document(
|
||||||
|
page_content=data[0][0],
|
||||||
|
metadata={"label": data[0][1]},
|
||||||
|
),
|
||||||
|
Document(
|
||||||
|
page_content=data[1][0],
|
||||||
|
metadata={"label": data[1][1]},
|
||||||
|
),
|
||||||
|
Document(
|
||||||
|
page_content=data[2][0],
|
||||||
|
metadata={"label": data[2][1]},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
loader = PySparkDataFrameLoader(
|
||||||
|
spark_session=spark, df=df, page_content_column="text"
|
||||||
|
)
|
||||||
|
result = loader.load()
|
||||||
|
|
||||||
|
assert result == expected_docs
|
Loading…
Reference in New Issue