mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
Add to support polars (#9610)
### Description Polars is a DataFrame interface on top of an OLAP Query Engine implemented in Rust. Polars is faster to read than pandas, so I'm looking forward to seeing it added to the document loader. ### Dependencies polars (https://pola-rs.github.io/polars-book/user-guide/) --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
3c4f32c8b8
commit
fba29f203a
225
docs/extras/integrations/document_loaders/polars_dataframe.ipynb
Normal file
225
docs/extras/integrations/document_loaders/polars_dataframe.ipynb
Normal file
@ -0,0 +1,225 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "213a38a2",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Polars DataFrame\n",
|
||||
"\n",
|
||||
"This notebook goes over how to load data from a [polars](https://pola-rs.github.io/polars-book/user-guide/) DataFrame."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "f6a7a9e4-80d6-486a-b2e3-636c568aa97c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#!pip install polars"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "79331964",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import polars as pl"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "e487044c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"df = pl.read_csv(\"example_data/mlb_teams_2012.csv\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "ac273ca1",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<div><style>\n",
|
||||
".dataframe > thead > tr > th,\n",
|
||||
".dataframe > tbody > tr > td {\n",
|
||||
" text-align: right;\n",
|
||||
"}\n",
|
||||
"</style>\n",
|
||||
"<small>shape: (5, 3)</small><table border=\"1\" class=\"dataframe\"><thead><tr><th>Team</th><th> "Payroll (millions)"</th><th> "Wins"</th></tr><tr><td>str</td><td>f64</td><td>i64</td></tr></thead><tbody><tr><td>"Nationals"</td><td>81.34</td><td>98</td></tr><tr><td>"Reds"</td><td>82.2</td><td>97</td></tr><tr><td>"Yankees"</td><td>197.96</td><td>95</td></tr><tr><td>"Giants"</td><td>117.62</td><td>94</td></tr><tr><td>"Braves"</td><td>83.31</td><td>94</td></tr></tbody></table></div>"
|
||||
],
|
||||
"text/plain": [
|
||||
"shape: (5, 3)\n",
|
||||
"┌───────────┬───────────────────────┬─────────┐\n",
|
||||
"│ Team ┆ \"Payroll (millions)\" ┆ \"Wins\" │\n",
|
||||
"│ --- ┆ --- ┆ --- │\n",
|
||||
"│ str ┆ f64 ┆ i64 │\n",
|
||||
"╞═══════════╪═══════════════════════╪═════════╡\n",
|
||||
"│ Nationals ┆ 81.34 ┆ 98 │\n",
|
||||
"│ Reds ┆ 82.2 ┆ 97 │\n",
|
||||
"│ Yankees ┆ 197.96 ┆ 95 │\n",
|
||||
"│ Giants ┆ 117.62 ┆ 94 │\n",
|
||||
"│ Braves ┆ 83.31 ┆ 94 │\n",
|
||||
"└───────────┴───────────────────────┴─────────┘"
|
||||
]
|
||||
},
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"df.head()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "66e47a13",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.document_loaders import PolarsDataFrameLoader"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "2334caca",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"loader = PolarsDataFrameLoader(df, page_content_column=\"Team\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "d616c2b0",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[Document(page_content='Nationals', metadata={' \"Payroll (millions)\"': 81.34, ' \"Wins\"': 98}),\n",
|
||||
" Document(page_content='Reds', metadata={' \"Payroll (millions)\"': 82.2, ' \"Wins\"': 97}),\n",
|
||||
" Document(page_content='Yankees', metadata={' \"Payroll (millions)\"': 197.96, ' \"Wins\"': 95}),\n",
|
||||
" Document(page_content='Giants', metadata={' \"Payroll (millions)\"': 117.62, ' \"Wins\"': 94}),\n",
|
||||
" Document(page_content='Braves', metadata={' \"Payroll (millions)\"': 83.31, ' \"Wins\"': 94}),\n",
|
||||
" Document(page_content='Athletics', metadata={' \"Payroll (millions)\"': 55.37, ' \"Wins\"': 94}),\n",
|
||||
" Document(page_content='Rangers', metadata={' \"Payroll (millions)\"': 120.51, ' \"Wins\"': 93}),\n",
|
||||
" Document(page_content='Orioles', metadata={' \"Payroll (millions)\"': 81.43, ' \"Wins\"': 93}),\n",
|
||||
" Document(page_content='Rays', metadata={' \"Payroll (millions)\"': 64.17, ' \"Wins\"': 90}),\n",
|
||||
" Document(page_content='Angels', metadata={' \"Payroll (millions)\"': 154.49, ' \"Wins\"': 89}),\n",
|
||||
" Document(page_content='Tigers', metadata={' \"Payroll (millions)\"': 132.3, ' \"Wins\"': 88}),\n",
|
||||
" Document(page_content='Cardinals', metadata={' \"Payroll (millions)\"': 110.3, ' \"Wins\"': 88}),\n",
|
||||
" Document(page_content='Dodgers', metadata={' \"Payroll (millions)\"': 95.14, ' \"Wins\"': 86}),\n",
|
||||
" Document(page_content='White Sox', metadata={' \"Payroll (millions)\"': 96.92, ' \"Wins\"': 85}),\n",
|
||||
" Document(page_content='Brewers', metadata={' \"Payroll (millions)\"': 97.65, ' \"Wins\"': 83}),\n",
|
||||
" Document(page_content='Phillies', metadata={' \"Payroll (millions)\"': 174.54, ' \"Wins\"': 81}),\n",
|
||||
" Document(page_content='Diamondbacks', metadata={' \"Payroll (millions)\"': 74.28, ' \"Wins\"': 81}),\n",
|
||||
" Document(page_content='Pirates', metadata={' \"Payroll (millions)\"': 63.43, ' \"Wins\"': 79}),\n",
|
||||
" Document(page_content='Padres', metadata={' \"Payroll (millions)\"': 55.24, ' \"Wins\"': 76}),\n",
|
||||
" Document(page_content='Mariners', metadata={' \"Payroll (millions)\"': 81.97, ' \"Wins\"': 75}),\n",
|
||||
" Document(page_content='Mets', metadata={' \"Payroll (millions)\"': 93.35, ' \"Wins\"': 74}),\n",
|
||||
" Document(page_content='Blue Jays', metadata={' \"Payroll (millions)\"': 75.48, ' \"Wins\"': 73}),\n",
|
||||
" Document(page_content='Royals', metadata={' \"Payroll (millions)\"': 60.91, ' \"Wins\"': 72}),\n",
|
||||
" Document(page_content='Marlins', metadata={' \"Payroll (millions)\"': 118.07, ' \"Wins\"': 69}),\n",
|
||||
" Document(page_content='Red Sox', metadata={' \"Payroll (millions)\"': 173.18, ' \"Wins\"': 69}),\n",
|
||||
" Document(page_content='Indians', metadata={' \"Payroll (millions)\"': 78.43, ' \"Wins\"': 68}),\n",
|
||||
" Document(page_content='Twins', metadata={' \"Payroll (millions)\"': 94.08, ' \"Wins\"': 66}),\n",
|
||||
" Document(page_content='Rockies', metadata={' \"Payroll (millions)\"': 78.06, ' \"Wins\"': 64}),\n",
|
||||
" Document(page_content='Cubs', metadata={' \"Payroll (millions)\"': 88.19, ' \"Wins\"': 61}),\n",
|
||||
" Document(page_content='Astros', metadata={' \"Payroll (millions)\"': 60.65, ' \"Wins\"': 55})]"
|
||||
]
|
||||
},
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"loader.load()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "beb55c2f",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"page_content='Nationals' metadata={' \"Payroll (millions)\"': 81.34, ' \"Wins\"': 98}\n",
|
||||
"page_content='Reds' metadata={' \"Payroll (millions)\"': 82.2, ' \"Wins\"': 97}\n",
|
||||
"page_content='Yankees' metadata={' \"Payroll (millions)\"': 197.96, ' \"Wins\"': 95}\n",
|
||||
"page_content='Giants' metadata={' \"Payroll (millions)\"': 117.62, ' \"Wins\"': 94}\n",
|
||||
"page_content='Braves' metadata={' \"Payroll (millions)\"': 83.31, ' \"Wins\"': 94}\n",
|
||||
"page_content='Athletics' metadata={' \"Payroll (millions)\"': 55.37, ' \"Wins\"': 94}\n",
|
||||
"page_content='Rangers' metadata={' \"Payroll (millions)\"': 120.51, ' \"Wins\"': 93}\n",
|
||||
"page_content='Orioles' metadata={' \"Payroll (millions)\"': 81.43, ' \"Wins\"': 93}\n",
|
||||
"page_content='Rays' metadata={' \"Payroll (millions)\"': 64.17, ' \"Wins\"': 90}\n",
|
||||
"page_content='Angels' metadata={' \"Payroll (millions)\"': 154.49, ' \"Wins\"': 89}\n",
|
||||
"page_content='Tigers' metadata={' \"Payroll (millions)\"': 132.3, ' \"Wins\"': 88}\n",
|
||||
"page_content='Cardinals' metadata={' \"Payroll (millions)\"': 110.3, ' \"Wins\"': 88}\n",
|
||||
"page_content='Dodgers' metadata={' \"Payroll (millions)\"': 95.14, ' \"Wins\"': 86}\n",
|
||||
"page_content='White Sox' metadata={' \"Payroll (millions)\"': 96.92, ' \"Wins\"': 85}\n",
|
||||
"page_content='Brewers' metadata={' \"Payroll (millions)\"': 97.65, ' \"Wins\"': 83}\n",
|
||||
"page_content='Phillies' metadata={' \"Payroll (millions)\"': 174.54, ' \"Wins\"': 81}\n",
|
||||
"page_content='Diamondbacks' metadata={' \"Payroll (millions)\"': 74.28, ' \"Wins\"': 81}\n",
|
||||
"page_content='Pirates' metadata={' \"Payroll (millions)\"': 63.43, ' \"Wins\"': 79}\n",
|
||||
"page_content='Padres' metadata={' \"Payroll (millions)\"': 55.24, ' \"Wins\"': 76}\n",
|
||||
"page_content='Mariners' metadata={' \"Payroll (millions)\"': 81.97, ' \"Wins\"': 75}\n",
|
||||
"page_content='Mets' metadata={' \"Payroll (millions)\"': 93.35, ' \"Wins\"': 74}\n",
|
||||
"page_content='Blue Jays' metadata={' \"Payroll (millions)\"': 75.48, ' \"Wins\"': 73}\n",
|
||||
"page_content='Royals' metadata={' \"Payroll (millions)\"': 60.91, ' \"Wins\"': 72}\n",
|
||||
"page_content='Marlins' metadata={' \"Payroll (millions)\"': 118.07, ' \"Wins\"': 69}\n",
|
||||
"page_content='Red Sox' metadata={' \"Payroll (millions)\"': 173.18, ' \"Wins\"': 69}\n",
|
||||
"page_content='Indians' metadata={' \"Payroll (millions)\"': 78.43, ' \"Wins\"': 68}\n",
|
||||
"page_content='Twins' metadata={' \"Payroll (millions)\"': 94.08, ' \"Wins\"': 66}\n",
|
||||
"page_content='Rockies' metadata={' \"Payroll (millions)\"': 78.06, ' \"Wins\"': 64}\n",
|
||||
"page_content='Cubs' metadata={' \"Payroll (millions)\"': 88.19, ' \"Wins\"': 61}\n",
|
||||
"page_content='Astros' metadata={' \"Payroll (millions)\"': 60.65, ' \"Wins\"': 55}\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Use lazy load for larger table, which won't read the full table into memory\n",
|
||||
"for i in loader.lazy_load():\n",
|
||||
" print(i)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
@ -132,6 +132,7 @@ from langchain.document_loaders.pdf import (
|
||||
PyPDFLoader,
|
||||
UnstructuredPDFLoader,
|
||||
)
|
||||
from langchain.document_loaders.polars_dataframe import PolarsDataFrameLoader
|
||||
from langchain.document_loaders.powerpoint import UnstructuredPowerPointLoader
|
||||
from langchain.document_loaders.psychic import PsychicLoader
|
||||
from langchain.document_loaders.pubmed import PubMedLoader
|
||||
@ -299,6 +300,7 @@ __all__ = [
|
||||
"PDFPlumberLoader",
|
||||
"PagedPDFSplitter",
|
||||
"PlaywrightURLLoader",
|
||||
"PolarsDataFrameLoader",
|
||||
"PsychicLoader",
|
||||
"PubMedLoader",
|
||||
"PyMuPDFLoader",
|
||||
|
@ -4,23 +4,15 @@ from langchain.docstore.document import Document
|
||||
from langchain.document_loaders.base import BaseLoader
|
||||
|
||||
|
||||
class DataFrameLoader(BaseLoader):
|
||||
"""Load `Pandas` DataFrame."""
|
||||
|
||||
def __init__(self, data_frame: Any, page_content_column: str = "text"):
|
||||
class BaseDataFrameLoader(BaseLoader):
|
||||
def __init__(self, data_frame: Any, *, page_content_column: str = "text"):
|
||||
"""Initialize with dataframe object.
|
||||
|
||||
Args:
|
||||
data_frame: Pandas DataFrame object.
|
||||
data_frame: DataFrame object.
|
||||
page_content_column: Name of the column containing the page content.
|
||||
Defaults to "text".
|
||||
"""
|
||||
import pandas as pd
|
||||
|
||||
if not isinstance(data_frame, pd.DataFrame):
|
||||
raise ValueError(
|
||||
f"Expected data_frame to be a pd.DataFrame, got {type(data_frame)}"
|
||||
)
|
||||
self.data_frame = data_frame
|
||||
self.page_content_column = page_content_column
|
||||
|
||||
@ -36,3 +28,28 @@ class DataFrameLoader(BaseLoader):
|
||||
def load(self) -> List[Document]:
|
||||
"""Load full dataframe."""
|
||||
return list(self.lazy_load())
|
||||
|
||||
|
||||
class DataFrameLoader(BaseDataFrameLoader):
|
||||
"""Load `Pandas` DataFrame."""
|
||||
|
||||
def __init__(self, data_frame: Any, page_content_column: str = "text"):
|
||||
"""Initialize with dataframe object.
|
||||
|
||||
Args:
|
||||
data_frame: Pandas DataFrame object.
|
||||
page_content_column: Name of the column containing the page content.
|
||||
Defaults to "text".
|
||||
"""
|
||||
try:
|
||||
import pandas as pd
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Unable to import pandas, please install with `pip install pandas`."
|
||||
) from e
|
||||
|
||||
if not isinstance(data_frame, pd.DataFrame):
|
||||
raise ValueError(
|
||||
f"Expected data_frame to be a pd.DataFrame, got {type(data_frame)}"
|
||||
)
|
||||
super().__init__(data_frame, page_content_column=page_content_column)
|
||||
|
@ -0,0 +1,32 @@
|
||||
from typing import Any, Iterator
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.document_loaders.dataframe import BaseDataFrameLoader
|
||||
|
||||
|
||||
class PolarsDataFrameLoader(BaseDataFrameLoader):
|
||||
"""Load `Polars` DataFrame."""
|
||||
|
||||
def __init__(self, data_frame: Any, *, page_content_column: str = "text"):
|
||||
"""Initialize with dataframe object.
|
||||
|
||||
Args:
|
||||
data_frame: Polars DataFrame object.
|
||||
page_content_column: Name of the column containing the page content.
|
||||
Defaults to "text".
|
||||
"""
|
||||
import polars as pl
|
||||
|
||||
if not isinstance(data_frame, pl.DataFrame):
|
||||
raise ValueError(
|
||||
f"Expected data_frame to be a pl.DataFrame, got {type(data_frame)}"
|
||||
)
|
||||
super().__init__(data_frame, page_content_column=page_content_column)
|
||||
|
||||
def lazy_load(self) -> Iterator[Document]:
|
||||
"""Lazy load records from dataframe."""
|
||||
|
||||
for row in self.data_frame.iter_rows(named=True):
|
||||
text = row[self.page_content_column]
|
||||
row.pop(self.page_content_column)
|
||||
yield Document(page_content=text, metadata=row)
|
@ -1,10 +1,9 @@
|
||||
from typing import Any, Iterator, List
|
||||
from typing import Any
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.document_loaders.base import BaseLoader
|
||||
from langchain.document_loaders.dataframe import BaseDataFrameLoader
|
||||
|
||||
|
||||
class XorbitsLoader(BaseLoader):
|
||||
class XorbitsLoader(BaseDataFrameLoader):
|
||||
"""Load `Xorbits` DataFrame."""
|
||||
|
||||
def __init__(self, data_frame: Any, page_content_column: str = "text"):
|
||||
@ -30,17 +29,4 @@ class XorbitsLoader(BaseLoader):
|
||||
f"Expected data_frame to be a xorbits.pandas.DataFrame, \
|
||||
got {type(data_frame)}"
|
||||
)
|
||||
self.data_frame = data_frame
|
||||
self.page_content_column = page_content_column
|
||||
|
||||
def lazy_load(self) -> Iterator[Document]:
|
||||
"""Lazy load records from dataframe."""
|
||||
for _, row in self.data_frame.iterrows():
|
||||
text = row[self.page_content_column]
|
||||
metadata = row.to_dict()
|
||||
metadata.pop(self.page_content_column)
|
||||
yield Document(page_content=text, metadata=metadata)
|
||||
|
||||
def load(self) -> List[Document]:
|
||||
"""Load full dataframe."""
|
||||
return list(self.lazy_load())
|
||||
super().__init__(data_frame, page_content_column=page_content_column)
|
||||
|
@ -0,0 +1,48 @@
|
||||
import polars as pl
|
||||
import pytest
|
||||
|
||||
from langchain.document_loaders import PolarsDataFrameLoader
|
||||
from langchain.schema import Document
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_data_frame() -> pl.DataFrame:
|
||||
data = {
|
||||
"text": ["Hello", "World"],
|
||||
"author": ["Alice", "Bob"],
|
||||
"date": ["2022-01-01", "2022-01-02"],
|
||||
}
|
||||
return pl.DataFrame(data)
|
||||
|
||||
|
||||
def test_load_returns_list_of_documents(sample_data_frame: pl.DataFrame) -> None:
|
||||
loader = PolarsDataFrameLoader(sample_data_frame)
|
||||
docs = loader.load()
|
||||
assert isinstance(docs, list)
|
||||
assert all(isinstance(doc, Document) for doc in docs)
|
||||
assert len(docs) == 2
|
||||
|
||||
|
||||
def test_load_converts_dataframe_columns_to_document_metadata(
|
||||
sample_data_frame: pl.DataFrame,
|
||||
) -> None:
|
||||
loader = PolarsDataFrameLoader(sample_data_frame)
|
||||
docs = loader.load()
|
||||
|
||||
for i, doc in enumerate(docs):
|
||||
df: pl.DataFrame = sample_data_frame[i]
|
||||
assert df is not None
|
||||
assert doc.metadata["author"] == df.select("author").item()
|
||||
assert doc.metadata["date"] == df.select("date").item()
|
||||
|
||||
|
||||
def test_load_uses_page_content_column_to_create_document_text(
|
||||
sample_data_frame: pl.DataFrame,
|
||||
) -> None:
|
||||
sample_data_frame = sample_data_frame.rename(mapping={"text": "dummy_test_column"})
|
||||
loader = PolarsDataFrameLoader(
|
||||
sample_data_frame, page_content_column="dummy_test_column"
|
||||
)
|
||||
docs = loader.load()
|
||||
assert docs[0].page_content == "Hello"
|
||||
assert docs[1].page_content == "World"
|
Loading…
Reference in New Issue
Block a user