Add support for polars (#9610)

### Description
Polars is a DataFrame interface on top of an OLAP Query Engine
implemented in Rust.
Polars reads data faster than pandas, so I would like to see it supported as a
document loader.

### Dependencies
polars (https://pola-rs.github.io/polars-book/user-guide/)

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
toddkim95 2023-08-22 23:36:24 +09:00 committed by GitHub
parent 3c4f32c8b8
commit fba29f203a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 339 additions and 29 deletions

View File

@ -0,0 +1,225 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "213a38a2",
"metadata": {},
"source": [
"# Polars DataFrame\n",
"\n",
"This notebook goes over how to load data from a [polars](https://pola-rs.github.io/polars-book/user-guide/) DataFrame."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "f6a7a9e4-80d6-486a-b2e3-636c568aa97c",
"metadata": {},
"outputs": [],
"source": [
"#!pip install polars"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "79331964",
"metadata": {},
"outputs": [],
"source": [
"import polars as pl"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "e487044c",
"metadata": {},
"outputs": [],
"source": [
"df = pl.read_csv(\"example_data/mlb_teams_2012.csv\")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "ac273ca1",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div><style>\n",
".dataframe > thead > tr > th,\n",
".dataframe > tbody > tr > td {\n",
" text-align: right;\n",
"}\n",
"</style>\n",
"<small>shape: (5, 3)</small><table border=\"1\" class=\"dataframe\"><thead><tr><th>Team</th><th> &quot;Payroll (millions)&quot;</th><th> &quot;Wins&quot;</th></tr><tr><td>str</td><td>f64</td><td>i64</td></tr></thead><tbody><tr><td>&quot;Nationals&quot;</td><td>81.34</td><td>98</td></tr><tr><td>&quot;Reds&quot;</td><td>82.2</td><td>97</td></tr><tr><td>&quot;Yankees&quot;</td><td>197.96</td><td>95</td></tr><tr><td>&quot;Giants&quot;</td><td>117.62</td><td>94</td></tr><tr><td>&quot;Braves&quot;</td><td>83.31</td><td>94</td></tr></tbody></table></div>"
],
"text/plain": [
"shape: (5, 3)\n",
"┌───────────┬───────────────────────┬─────────┐\n",
"│ Team ┆ \"Payroll (millions)\" ┆ \"Wins\" │\n",
"│ --- ┆ --- ┆ --- │\n",
"│ str ┆ f64 ┆ i64 │\n",
"╞═══════════╪═══════════════════════╪═════════╡\n",
"│ Nationals ┆ 81.34 ┆ 98 │\n",
"│ Reds ┆ 82.2 ┆ 97 │\n",
"│ Yankees ┆ 197.96 ┆ 95 │\n",
"│ Giants ┆ 117.62 ┆ 94 │\n",
"│ Braves ┆ 83.31 ┆ 94 │\n",
"└───────────┴───────────────────────┴─────────┘"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.head()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "66e47a13",
"metadata": {},
"outputs": [],
"source": [
"from langchain.document_loaders import PolarsDataFrameLoader"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "2334caca",
"metadata": {},
"outputs": [],
"source": [
"loader = PolarsDataFrameLoader(df, page_content_column=\"Team\")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "d616c2b0",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[Document(page_content='Nationals', metadata={' \"Payroll (millions)\"': 81.34, ' \"Wins\"': 98}),\n",
" Document(page_content='Reds', metadata={' \"Payroll (millions)\"': 82.2, ' \"Wins\"': 97}),\n",
" Document(page_content='Yankees', metadata={' \"Payroll (millions)\"': 197.96, ' \"Wins\"': 95}),\n",
" Document(page_content='Giants', metadata={' \"Payroll (millions)\"': 117.62, ' \"Wins\"': 94}),\n",
" Document(page_content='Braves', metadata={' \"Payroll (millions)\"': 83.31, ' \"Wins\"': 94}),\n",
" Document(page_content='Athletics', metadata={' \"Payroll (millions)\"': 55.37, ' \"Wins\"': 94}),\n",
" Document(page_content='Rangers', metadata={' \"Payroll (millions)\"': 120.51, ' \"Wins\"': 93}),\n",
" Document(page_content='Orioles', metadata={' \"Payroll (millions)\"': 81.43, ' \"Wins\"': 93}),\n",
" Document(page_content='Rays', metadata={' \"Payroll (millions)\"': 64.17, ' \"Wins\"': 90}),\n",
" Document(page_content='Angels', metadata={' \"Payroll (millions)\"': 154.49, ' \"Wins\"': 89}),\n",
" Document(page_content='Tigers', metadata={' \"Payroll (millions)\"': 132.3, ' \"Wins\"': 88}),\n",
" Document(page_content='Cardinals', metadata={' \"Payroll (millions)\"': 110.3, ' \"Wins\"': 88}),\n",
" Document(page_content='Dodgers', metadata={' \"Payroll (millions)\"': 95.14, ' \"Wins\"': 86}),\n",
" Document(page_content='White Sox', metadata={' \"Payroll (millions)\"': 96.92, ' \"Wins\"': 85}),\n",
" Document(page_content='Brewers', metadata={' \"Payroll (millions)\"': 97.65, ' \"Wins\"': 83}),\n",
" Document(page_content='Phillies', metadata={' \"Payroll (millions)\"': 174.54, ' \"Wins\"': 81}),\n",
" Document(page_content='Diamondbacks', metadata={' \"Payroll (millions)\"': 74.28, ' \"Wins\"': 81}),\n",
" Document(page_content='Pirates', metadata={' \"Payroll (millions)\"': 63.43, ' \"Wins\"': 79}),\n",
" Document(page_content='Padres', metadata={' \"Payroll (millions)\"': 55.24, ' \"Wins\"': 76}),\n",
" Document(page_content='Mariners', metadata={' \"Payroll (millions)\"': 81.97, ' \"Wins\"': 75}),\n",
" Document(page_content='Mets', metadata={' \"Payroll (millions)\"': 93.35, ' \"Wins\"': 74}),\n",
" Document(page_content='Blue Jays', metadata={' \"Payroll (millions)\"': 75.48, ' \"Wins\"': 73}),\n",
" Document(page_content='Royals', metadata={' \"Payroll (millions)\"': 60.91, ' \"Wins\"': 72}),\n",
" Document(page_content='Marlins', metadata={' \"Payroll (millions)\"': 118.07, ' \"Wins\"': 69}),\n",
" Document(page_content='Red Sox', metadata={' \"Payroll (millions)\"': 173.18, ' \"Wins\"': 69}),\n",
" Document(page_content='Indians', metadata={' \"Payroll (millions)\"': 78.43, ' \"Wins\"': 68}),\n",
" Document(page_content='Twins', metadata={' \"Payroll (millions)\"': 94.08, ' \"Wins\"': 66}),\n",
" Document(page_content='Rockies', metadata={' \"Payroll (millions)\"': 78.06, ' \"Wins\"': 64}),\n",
" Document(page_content='Cubs', metadata={' \"Payroll (millions)\"': 88.19, ' \"Wins\"': 61}),\n",
" Document(page_content='Astros', metadata={' \"Payroll (millions)\"': 60.65, ' \"Wins\"': 55})]"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"loader.load()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "beb55c2f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"page_content='Nationals' metadata={' \"Payroll (millions)\"': 81.34, ' \"Wins\"': 98}\n",
"page_content='Reds' metadata={' \"Payroll (millions)\"': 82.2, ' \"Wins\"': 97}\n",
"page_content='Yankees' metadata={' \"Payroll (millions)\"': 197.96, ' \"Wins\"': 95}\n",
"page_content='Giants' metadata={' \"Payroll (millions)\"': 117.62, ' \"Wins\"': 94}\n",
"page_content='Braves' metadata={' \"Payroll (millions)\"': 83.31, ' \"Wins\"': 94}\n",
"page_content='Athletics' metadata={' \"Payroll (millions)\"': 55.37, ' \"Wins\"': 94}\n",
"page_content='Rangers' metadata={' \"Payroll (millions)\"': 120.51, ' \"Wins\"': 93}\n",
"page_content='Orioles' metadata={' \"Payroll (millions)\"': 81.43, ' \"Wins\"': 93}\n",
"page_content='Rays' metadata={' \"Payroll (millions)\"': 64.17, ' \"Wins\"': 90}\n",
"page_content='Angels' metadata={' \"Payroll (millions)\"': 154.49, ' \"Wins\"': 89}\n",
"page_content='Tigers' metadata={' \"Payroll (millions)\"': 132.3, ' \"Wins\"': 88}\n",
"page_content='Cardinals' metadata={' \"Payroll (millions)\"': 110.3, ' \"Wins\"': 88}\n",
"page_content='Dodgers' metadata={' \"Payroll (millions)\"': 95.14, ' \"Wins\"': 86}\n",
"page_content='White Sox' metadata={' \"Payroll (millions)\"': 96.92, ' \"Wins\"': 85}\n",
"page_content='Brewers' metadata={' \"Payroll (millions)\"': 97.65, ' \"Wins\"': 83}\n",
"page_content='Phillies' metadata={' \"Payroll (millions)\"': 174.54, ' \"Wins\"': 81}\n",
"page_content='Diamondbacks' metadata={' \"Payroll (millions)\"': 74.28, ' \"Wins\"': 81}\n",
"page_content='Pirates' metadata={' \"Payroll (millions)\"': 63.43, ' \"Wins\"': 79}\n",
"page_content='Padres' metadata={' \"Payroll (millions)\"': 55.24, ' \"Wins\"': 76}\n",
"page_content='Mariners' metadata={' \"Payroll (millions)\"': 81.97, ' \"Wins\"': 75}\n",
"page_content='Mets' metadata={' \"Payroll (millions)\"': 93.35, ' \"Wins\"': 74}\n",
"page_content='Blue Jays' metadata={' \"Payroll (millions)\"': 75.48, ' \"Wins\"': 73}\n",
"page_content='Royals' metadata={' \"Payroll (millions)\"': 60.91, ' \"Wins\"': 72}\n",
"page_content='Marlins' metadata={' \"Payroll (millions)\"': 118.07, ' \"Wins\"': 69}\n",
"page_content='Red Sox' metadata={' \"Payroll (millions)\"': 173.18, ' \"Wins\"': 69}\n",
"page_content='Indians' metadata={' \"Payroll (millions)\"': 78.43, ' \"Wins\"': 68}\n",
"page_content='Twins' metadata={' \"Payroll (millions)\"': 94.08, ' \"Wins\"': 66}\n",
"page_content='Rockies' metadata={' \"Payroll (millions)\"': 78.06, ' \"Wins\"': 64}\n",
"page_content='Cubs' metadata={' \"Payroll (millions)\"': 88.19, ' \"Wins\"': 61}\n",
"page_content='Astros' metadata={' \"Payroll (millions)\"': 60.65, ' \"Wins\"': 55}\n"
]
}
],
"source": [
"# Use lazy load for larger table, which won't read the full table into memory\n",
"for i in loader.lazy_load():\n",
" print(i)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@ -132,6 +132,7 @@ from langchain.document_loaders.pdf import (
PyPDFLoader,
UnstructuredPDFLoader,
)
from langchain.document_loaders.polars_dataframe import PolarsDataFrameLoader
from langchain.document_loaders.powerpoint import UnstructuredPowerPointLoader
from langchain.document_loaders.psychic import PsychicLoader
from langchain.document_loaders.pubmed import PubMedLoader
@ -299,6 +300,7 @@ __all__ = [
"PDFPlumberLoader",
"PagedPDFSplitter",
"PlaywrightURLLoader",
"PolarsDataFrameLoader",
"PsychicLoader",
"PubMedLoader",
"PyMuPDFLoader",

View File

@ -4,23 +4,15 @@ from langchain.docstore.document import Document
from langchain.document_loaders.base import BaseLoader
class DataFrameLoader(BaseLoader):
"""Load `Pandas` DataFrame."""
def __init__(self, data_frame: Any, page_content_column: str = "text"):
class BaseDataFrameLoader(BaseLoader):
def __init__(self, data_frame: Any, *, page_content_column: str = "text"):
"""Initialize with dataframe object.
Args:
data_frame: Pandas DataFrame object.
data_frame: DataFrame object.
page_content_column: Name of the column containing the page content.
Defaults to "text".
"""
import pandas as pd
if not isinstance(data_frame, pd.DataFrame):
raise ValueError(
f"Expected data_frame to be a pd.DataFrame, got {type(data_frame)}"
)
self.data_frame = data_frame
self.page_content_column = page_content_column
@ -36,3 +28,28 @@ class DataFrameLoader(BaseLoader):
def load(self) -> List[Document]:
"""Load full dataframe."""
return list(self.lazy_load())
class DataFrameLoader(BaseDataFrameLoader):
    """Load `Pandas` DataFrame."""

    def __init__(self, data_frame: Any, page_content_column: str = "text"):
        """Check that the input is a Pandas DataFrame, then hand it to the base loader.

        Args:
            data_frame: Pandas DataFrame object.
            page_content_column: Name of the column containing the page content.
              Defaults to "text".

        Raises:
            ImportError: If pandas cannot be imported.
            ValueError: If ``data_frame`` is not a ``pd.DataFrame``.
        """
        # pandas is imported lazily so the module can be loaded without it.
        try:
            import pandas as pd
        except ImportError as exc:
            install_hint = (
                "Unable to import pandas, please install with `pip install pandas`."
            )
            raise ImportError(install_hint) from exc

        is_pandas_frame = isinstance(data_frame, pd.DataFrame)
        if not is_pandas_frame:
            received_type = type(data_frame)
            raise ValueError(
                f"Expected data_frame to be a pd.DataFrame, got {received_type}"
            )

        super().__init__(data_frame, page_content_column=page_content_column)

View File

@ -0,0 +1,32 @@
from typing import Any, Iterator
from langchain.docstore.document import Document
from langchain.document_loaders.dataframe import BaseDataFrameLoader
class PolarsDataFrameLoader(BaseDataFrameLoader):
    """Load `Polars` DataFrame."""

    def __init__(self, data_frame: Any, *, page_content_column: str = "text"):
        """Initialize with dataframe object.

        Args:
            data_frame: Polars DataFrame object.
            page_content_column: Name of the column containing the page content.
              Defaults to "text".

        Raises:
            ImportError: If polars cannot be imported.
            ValueError: If ``data_frame`` is not a ``pl.DataFrame``.
        """
        # polars is an optional dependency: import it lazily and, when it is
        # missing, tell the user how to install it instead of surfacing a bare
        # import failure.
        try:
            import polars as pl
        except ImportError as e:
            raise ImportError(
                "Unable to import polars, please install with `pip install polars`."
            ) from e

        if not isinstance(data_frame, pl.DataFrame):
            raise ValueError(
                f"Expected data_frame to be a pl.DataFrame, got {type(data_frame)}"
            )

        super().__init__(data_frame, page_content_column=page_content_column)

    def lazy_load(self) -> Iterator[Document]:
        """Lazy load records from dataframe.

        Yields:
            One Document per row. The value in ``page_content_column`` becomes the
            page content; every other column of the row becomes metadata.

        Raises:
            KeyError: If ``page_content_column`` is not a column of the dataframe.
        """
        for row in self.data_frame.iter_rows(named=True):
            # A single ``pop`` both reads the content value and removes it, so
            # whatever is left in the row mapping is exactly the metadata.
            text = row.pop(self.page_content_column)
            yield Document(page_content=text, metadata=row)

View File

@ -1,10 +1,9 @@
from typing import Any, Iterator, List
from typing import Any
from langchain.docstore.document import Document
from langchain.document_loaders.base import BaseLoader
from langchain.document_loaders.dataframe import BaseDataFrameLoader
class XorbitsLoader(BaseLoader):
class XorbitsLoader(BaseDataFrameLoader):
"""Load `Xorbits` DataFrame."""
def __init__(self, data_frame: Any, page_content_column: str = "text"):
@ -30,17 +29,4 @@ class XorbitsLoader(BaseLoader):
f"Expected data_frame to be a xorbits.pandas.DataFrame, \
got {type(data_frame)}"
)
self.data_frame = data_frame
self.page_content_column = page_content_column
def lazy_load(self) -> Iterator[Document]:
"""Lazy load records from dataframe."""
for _, row in self.data_frame.iterrows():
text = row[self.page_content_column]
metadata = row.to_dict()
metadata.pop(self.page_content_column)
yield Document(page_content=text, metadata=metadata)
def load(self) -> List[Document]:
"""Load full dataframe."""
return list(self.lazy_load())
super().__init__(data_frame, page_content_column=page_content_column)

View File

@ -0,0 +1,48 @@
import polars as pl
import pytest
from langchain.document_loaders import PolarsDataFrameLoader
from langchain.schema import Document
@pytest.fixture
def sample_data_frame() -> pl.DataFrame:
    """Provide a two-row frame with ``text``, ``author`` and ``date`` columns."""
    return pl.DataFrame(
        {
            "text": ["Hello", "World"],
            "author": ["Alice", "Bob"],
            "date": ["2022-01-01", "2022-01-02"],
        }
    )
def test_load_returns_list_of_documents(sample_data_frame: pl.DataFrame) -> None:
    """``load`` must return a plain list holding one Document per input row."""
    documents = PolarsDataFrameLoader(sample_data_frame).load()

    assert isinstance(documents, list)
    for document in documents:
        assert isinstance(document, Document)
    # The fixture frame has exactly two rows.
    assert len(documents) == 2
def test_load_converts_dataframe_columns_to_document_metadata(
    sample_data_frame: pl.DataFrame,
) -> None:
    """Each non-content column must show up in the matching document's metadata."""
    documents = PolarsDataFrameLoader(sample_data_frame).load()

    for index, document in enumerate(documents):
        row_frame: pl.DataFrame = sample_data_frame[index]
        assert row_frame is not None
        # Compare each metadata entry with the single value selected from the
        # source row for the same column.
        for column in ("author", "date"):
            assert document.metadata[column] == row_frame.select(column).item()
def test_load_uses_page_content_column_to_create_document_text(
    sample_data_frame: pl.DataFrame,
) -> None:
    """A non-default ``page_content_column`` must drive the document text."""
    content_column = "dummy_test_column"
    renamed_frame = sample_data_frame.rename(mapping={"text": content_column})

    documents = PolarsDataFrameLoader(
        renamed_frame, page_content_column=content_column
    ).load()

    first, second = documents[0], documents[1]
    assert first.page_content == "Hello"
    assert second.page_content == "World"