diff --git a/langchain/document_loaders/snowflake_loader.py b/langchain/document_loaders/snowflake_loader.py index b76a7426..59164124 100644 --- a/langchain/document_loaders/snowflake_loader.py +++ b/langchain/document_loaders/snowflake_loader.py @@ -53,13 +53,14 @@ class SnowflakeLoader(BaseLoader): self.database = database self.schema = schema self.parameters = parameters - self.page_content_columns = page_content_columns - self.metadata_columns = metadata_columns + self.page_content_columns = ( + page_content_columns if page_content_columns is not None else ["*"] + ) + self.metadata_columns = metadata_columns if metadata_columns is not None else [] def _execute_query(self) -> List[Dict[str, Any]]: try: import snowflake.connector - from snowflake.connector import DictCursor except ImportError as ex: raise ValueError( "Could not import snowflake-connector-python package. " @@ -77,14 +78,13 @@ class SnowflakeLoader(BaseLoader): parameters=self.parameters, ) try: - cur = conn.cursor(DictCursor) + cur = conn.cursor() cur.execute("USE DATABASE " + self.database) cur.execute("USE SCHEMA " + self.schema) cur.execute(self.query, self.parameters) query_result = cur.fetchall() - query_result = [ - {k.lower(): v for k, v in item.items()} for item in query_result - ] + column_names = [column[0] for column in cur.description] + query_result = [dict(zip(column_names, row)) for row in query_result] except Exception as e: print(f"An error occurred: {e}") query_result = [] @@ -111,6 +111,8 @@ class SnowflakeLoader(BaseLoader): print(f"An error occurred during the query: {query_result}") return [] page_content_columns, metadata_columns = self._get_columns(query_result) + if "*" in page_content_columns: + page_content_columns = list(query_result[0].keys()) for row in query_result: page_content = "\n".join( f"{k}: {v}" for k, v in row.items() if k in page_content_columns @@ -120,4 +122,5 @@ class SnowflakeLoader(BaseLoader): yield doc def load(self) -> List[Document]: + """Load data into document objects.""" return list(self.lazy_load())