forked from Archives/langchain
fix (#5457)
commit 2649b638dd
parent 64b4165c8d
@@ -4,8 +4,6 @@ import logging
 import sys
 from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Tuple
 
-import psutil
-
 from langchain.docstore.document import Document
 from langchain.document_loaders.base import BaseLoader
 
@@ -29,7 +27,7 @@ class PySparkDataFrameLoader(BaseLoader):
         try:
             from pyspark.sql import DataFrame, SparkSession
         except ImportError:
-            raise ValueError(
+            raise ImportError(
                 "pyspark is not installed. "
                 "Please install it with `pip install pyspark`"
             )
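For context: the loader previously raised ValueError when pyspark was missing, and the hunk above changes that to ImportError, the conventional exception for an absent optional dependency. A minimal standalone sketch of the same guard pattern follows; the require_pyspark helper is a hypothetical name for illustration, not part of the loader or the langchain API.

# Hypothetical helper illustrating the import-guard pattern from the hunk above.
def require_pyspark():
    try:
        from pyspark.sql import SparkSession
    except ImportError as e:
        # Raise ImportError (not ValueError) so callers can catch the
        # standard exception type for a missing optional dependency.
        raise ImportError(
            "pyspark is not installed. "
            "Please install it with `pip install pyspark`"
        ) from e
    return SparkSession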
@@ -51,6 +49,12 @@ class PySparkDataFrameLoader(BaseLoader):
 
     def get_num_rows(self) -> Tuple[int, int]:
         """Gets the amount of "feasible" rows for the DataFrame"""
+        try:
+            import psutil
+        except ImportError as e:
+            raise ImportError(
+                "psutil not installed. Please install it with `pip install psutil`."
+            ) from e
         row = self.df.limit(1).collect()[0]
         estimated_row_size = sys.getsizeof(row)
         mem_info = psutil.virtual_memory()
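The first and last hunks pair up: the module-level import psutil removed at the top is re-introduced lazily inside get_num_rows, so importing the module no longer requires psutil at all. A standalone sketch of that pattern is below, with a hypothetical estimate_feasible_rows function; the closing division of available memory by the estimated row size is an illustrative assumption, not necessarily the loader's exact formula.

import sys


def estimate_feasible_rows(sample_row) -> int:
    # Import psutil lazily so the surrounding module can be imported
    # even when psutil is not installed.
    try:
        import psutil
    except ImportError as e:
        raise ImportError(
            "psutil not installed. Please install it with `pip install psutil`."
        ) from e
    estimated_row_size = sys.getsizeof(sample_row)  # rough per-row size in bytes
    available_bytes = psutil.virtual_memory().available  # currently free RAM
    return available_bytes // max(estimated_row_size, 1)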