From 2649b638dd36a0786a4f5368713242324c611bcb Mon Sep 17 00:00:00 2001 From: Davis Chase <130488702+dev2049@users.noreply.github.com> Date: Tue, 30 May 2023 10:42:20 -0700 Subject: [PATCH] fix (#5457) --- langchain/document_loaders/pyspark_dataframe.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/langchain/document_loaders/pyspark_dataframe.py b/langchain/document_loaders/pyspark_dataframe.py index fe8fe4e7..c1f186cd 100644 --- a/langchain/document_loaders/pyspark_dataframe.py +++ b/langchain/document_loaders/pyspark_dataframe.py @@ -4,8 +4,6 @@ import logging import sys from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Tuple -import psutil - from langchain.docstore.document import Document from langchain.document_loaders.base import BaseLoader @@ -29,7 +27,7 @@ class PySparkDataFrameLoader(BaseLoader): try: from pyspark.sql import DataFrame, SparkSession except ImportError: - raise ValueError( + raise ImportError( "pyspark is not installed. " "Please install it with `pip install pyspark`" ) @@ -51,6 +49,12 @@ class PySparkDataFrameLoader(BaseLoader): def get_num_rows(self) -> Tuple[int, int]: """Gets the amount of "feasible" rows for the DataFrame""" + try: + import psutil + except ImportError as e: + raise ImportError( + "psutil not installed. Please install it with `pip install psutil`." + ) from e row = self.df.limit(1).collect()[0] estimated_row_size = sys.getsizeof(row) mem_info = psutil.virtual_memory()