mirror of
https://github.com/hwchase17/langchain
synced 2024-10-29 17:07:25 +00:00
760632b292
Co-authored-by: Rithwik Ediga Lakhamsani <rithwik.ediga@databricks.com> Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>
39 lines
1.0 KiB
Python
39 lines
1.0 KiB
Python
import random
|
|
import string
|
|
|
|
from langchain.docstore.document import Document
|
|
from langchain.document_loaders.pyspark_dataframe import PySparkDataFrameLoader
|
|
|
|
|
|
def test_pyspark_loader_load_valid_data() -> None:
|
|
from pyspark.sql import SparkSession
|
|
|
|
# Requires a session to be set up
|
|
spark = SparkSession.builder.getOrCreate()
|
|
data = [
|
|
(random.choice(string.ascii_letters), random.randint(0, 1)) for _ in range(3)
|
|
]
|
|
df = spark.createDataFrame(data, ["text", "label"])
|
|
|
|
expected_docs = [
|
|
Document(
|
|
page_content=data[0][0],
|
|
metadata={"label": data[0][1]},
|
|
),
|
|
Document(
|
|
page_content=data[1][0],
|
|
metadata={"label": data[1][1]},
|
|
),
|
|
Document(
|
|
page_content=data[2][0],
|
|
metadata={"label": data[2][1]},
|
|
),
|
|
]
|
|
|
|
loader = PySparkDataFrameLoader(
|
|
spark_session=spark, df=df, page_content_column="text"
|
|
)
|
|
result = loader.load()
|
|
|
|
assert result == expected_docs
|