mirror of https://github.com/hwchase17/langchain
Harrison/spark reader (#5405)
Co-authored-by: Rithwik Ediga Lakhamsani <rithwik.ediga@databricks.com> Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>pull/5425/head
parent
8259f9b7fa
commit
760632b292
@ -0,0 +1,97 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# PySpack DataFrame Loader\n",
|
||||
"\n",
|
||||
"This shows how to load data from a PySpark DataFrame"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#!pip install pyspark"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from pyspark.sql import SparkSession"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"spark = SparkSession.builder.getOrCreate()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"df = spark.read.csv('example_data/mlb_teams_2012.csv', header=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.document_loaders import PySparkDataFrameLoader"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"loader = PySparkDataFrameLoader(spark, df, page_content_column=\"Team\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"loader.load()"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.1"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
@ -0,0 +1,80 @@
|
||||
"""Load from a Spark Dataframe object"""
|
||||
import itertools
|
||||
import logging
|
||||
import sys
|
||||
from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Tuple
|
||||
|
||||
import psutil
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.document_loaders.base import BaseLoader
|
||||
|
||||
# Use the module's import path (``__name__``), not its filesystem path, as the
# logger name — the logging convention, and what ``logging.config`` hierarchies
# key on.
logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pyspark.sql import SparkSession
|
||||
|
||||
|
||||
class PySparkDataFrameLoader(BaseLoader):
    """Load documents from a PySpark DataFrame.

    Each row becomes one ``Document``: the value of ``page_content_column``
    is the page content, and every remaining column is stored in the
    document's metadata. The number of rows materialized by :meth:`load`
    is capped by an estimate of available memory (see :meth:`get_num_rows`).
    """

    def __init__(
        self,
        spark_session: Optional["SparkSession"] = None,
        df: Optional[Any] = None,
        page_content_column: str = "text",
        fraction_of_memory: float = 0.1,
    ):
        """Initialize with a Spark DataFrame object.

        Args:
            spark_session: Existing ``SparkSession``; if omitted, a default
                one is created via ``SparkSession.builder.getOrCreate()``.
            df: The PySpark DataFrame to load documents from.
            page_content_column: Column whose values become each Document's
                ``page_content``; all other columns go into metadata.
            fraction_of_memory: Fraction of currently-available memory the
                loaded rows may occupy; used to bound the row count.
        """
        try:
            from pyspark.sql import DataFrame, SparkSession
        except ImportError:
            raise ValueError(
                "pyspark is not installed. "
                "Please install it with `pip install pyspark`"
            )

        self.spark = (
            spark_session if spark_session else SparkSession.builder.getOrCreate()
        )

        if not isinstance(df, DataFrame):
            raise ValueError(
                f"Expected data_frame to be a PySpark DataFrame, got {type(df)}"
            )
        self.df = df
        self.page_content_column = page_content_column
        self.fraction_of_memory = fraction_of_memory
        self.num_rows, self.max_num_rows = self.get_num_rows()
        # Row-wise RDD view of the DataFrame; rows become plain lists so they
        # can be zipped with column names in lazy_load().
        self.rdd_df = self.df.rdd.map(list)
        self.column_names = self.df.columns

    def get_num_rows(self) -> Tuple[int, int]:
        """Return the amount of "feasible" rows for the DataFrame.

        Returns:
            A tuple ``(num_rows, max_num_rows)`` where ``max_num_rows`` is
            how many rows are estimated to fit in ``fraction_of_memory`` of
            available memory, and ``num_rows`` is that bound clamped to the
            DataFrame's actual row count.
        """
        sample = self.df.limit(1).collect()
        if not sample:
            # Empty DataFrame: no rows to size, and nothing to load.
            return 0, 0
        # Estimate per-row memory from a single collected Row. This is a
        # shallow sys.getsizeof estimate — a deliberate heuristic, not exact.
        estimated_row_size = sys.getsizeof(sample[0])
        available_memory = psutil.virtual_memory().available
        max_num_rows = int(
            (available_memory / estimated_row_size) * self.fraction_of_memory
        )
        return min(max_num_rows, self.df.count()), max_num_rows

    def lazy_load(self) -> Iterator[Document]:
        """A lazy loader for document content.

        Streams rows via ``toLocalIterator`` so the whole DataFrame is never
        collected at once.
        """
        for row in self.rdd_df.toLocalIterator():
            metadata = {self.column_names[i]: row[i] for i in range(len(row))}
            text = metadata[self.page_content_column]
            # The page-content column must not also appear in metadata.
            metadata.pop(self.page_content_column)
            yield Document(page_content=text, metadata=metadata)

    def load(self) -> List[Document]:
        """Load from the dataframe.

        At most ``self.num_rows`` documents are returned; a warning is
        emitted when the DataFrame is larger than the memory-based cap.
        """
        # count() is a full Spark action — evaluate it once, not per use.
        total_rows = self.df.count()
        if total_rows > self.max_num_rows:
            logger.warning(
                f"The number of DataFrame rows is {total_rows}, "
                f"but we will only include the amount "
                f"of rows that can reasonably fit in memory: {self.num_rows}."
            )
        return list(itertools.islice(self.lazy_load(), self.num_rows))
|
@ -0,0 +1,38 @@
|
||||
import random
|
||||
import string
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.document_loaders.pyspark_dataframe import PySparkDataFrameLoader
|
||||
|
||||
|
||||
def test_pyspark_loader_load_valid_data() -> None:
    """load() yields one Document per row, with the page-content column
    removed from metadata."""
    from pyspark.sql import SparkSession

    # Requires a session to be set up
    spark = SparkSession.builder.getOrCreate()
    data = [
        (random.choice(string.ascii_letters), random.randint(0, 1)) for _ in range(3)
    ]
    df = spark.createDataFrame(data, ["text", "label"])

    # One expected Document per input row: "text" becomes page_content and
    # "label" is the only remaining metadata column.
    expected_docs = [
        Document(page_content=text, metadata={"label": label})
        for text, label in data
    ]

    loader = PySparkDataFrameLoader(
        spark_session=spark, df=df, page_content_column="text"
    )
    result = loader.load()

    assert result == expected_docs
|
Loading…
Reference in New Issue