You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
langchain/langchain/document_loaders/notebook.py

110 lines
3.4 KiB
Python

"""Loader that loads .ipynb notebook files."""
import json
from pathlib import Path
from typing import Any, List
import pandas as pd
from langchain.docstore.document import Document
from langchain.document_loaders.base import BaseLoader
def concatenate_cells(
    cell: dict, include_outputs: bool, max_output_length: int, traceback: bool
) -> str:
    """Combine cells information in a readable format ready to be used.

    Args:
        cell: A notebook cell: a mapping (or pandas row) with ``cell_type`` and
            ``source`` keys and, for code cells, an ``outputs`` list.
        include_outputs: Whether to describe the first output of code cells.
        max_output_length: Maximum number of items kept from a stream output's
            ``text`` (characters when it is a string, lines when it is a list).
        traceback: Whether to append the traceback of an error output.

    Returns:
        A textual description of the cell. Every cell yields a non-empty
        description; only the first output of a code cell is inspected.
    """
    cell_type = cell["cell_type"]
    source = cell["source"]
    # Markdown/raw cells have no "outputs" key (or NaN after pandas
    # normalisation), so only a real list is treated as outputs.
    outputs = cell.get("outputs")
    plain = f"'{cell_type}' cell: '{source}'\n\n"

    has_outputs = isinstance(outputs, list) and len(outputs) > 0
    if not (include_outputs and cell_type == "code" and has_outputs):
        return plain

    first = outputs[0]
    if "ename" in first:
        error_name = first["ename"]
        error_value = first["evalue"]
        if traceback:
            error_traceback = first["traceback"]
            return (
                f"'{cell_type}' cell: '{source}'\n, gives error '{error_name}',"
                f" with description '{error_value}'\n"
                f"and traceback '{error_traceback}'\n\n"
            )
        return (
            f"'{cell_type}' cell: '{source}'\n, gives error '{error_name}',"
            f" with description '{error_value}'\n\n"
        )

    if first.get("output_type") == "stream":
        text = first["text"]
        return (
            f"'{cell_type}' cell: '{source}'\n with "
            f"output: '{text[:max_output_length]}'\n\n"
        )

    # Other output types (execute_result, display_data, ...) are not rendered,
    # but the cell's source must still be kept rather than dropped.
    return plain
def remove_newlines(x: Any) -> Any:
    """Recursively remove newline characters, whatever structure holds them.

    Strings have every newline character stripped. Lists, dicts (values only,
    keys are kept as-is) and DataFrames (element-wise) are processed
    recursively and a new object is returned. Any other value is returned
    unchanged.
    """
    if isinstance(x, str):
        return x.replace("\n", "")
    if isinstance(x, list):
        return [remove_newlines(elem) for elem in x]
    if isinstance(x, dict):
        # Notebook outputs are lists of dicts, so mappings must be walked too.
        return {key: remove_newlines(value) for key, value in x.items()}
    if isinstance(x, pd.DataFrame):
        # DataFrame.applymap was deprecated in favour of DataFrame.map (pandas 2.1).
        if hasattr(pd.DataFrame, "map"):
            return x.map(remove_newlines)
        return x.applymap(remove_newlines)
    return x
class NotebookLoader(BaseLoader):
    """Loader that loads .ipynb notebook files."""

    def __init__(
        self,
        path: str,
        include_outputs: bool = False,
        max_output_length: int = 10,
        remove_newline: bool = False,
        traceback: bool = False,
    ) -> None:
        """Initialize with path.

        Args:
            path: Path to the ``.ipynb`` file.
            include_outputs: Whether to describe the first output of code cells.
            max_output_length: Maximum length kept from a stream output.
            remove_newline: Whether to strip newline characters from the
                ``cell_type``/``source``/``outputs`` values before formatting.
            traceback: Whether to include the traceback of error outputs.
        """
        self.file_path = path
        self.include_outputs = include_outputs
        self.max_output_length = max_output_length
        self.remove_newline = remove_newline
        self.traceback = traceback

    def load(
        self,
    ) -> List[Document]:
        """Load the notebook as a single Document.

        Returns:
            A one-element list whose Document's ``page_content`` is the
            space-joined description of every cell and whose metadata
            ``source`` is the notebook path.

        Raises:
            ValueError: If pandas is not installed.
        """
        try:
            import pandas as pd
        except ImportError as e:
            raise ValueError(
                "pandas is needed for Notebook Loader, "
                "please install with `pip install pandas`"
            ) from e

        p = Path(self.file_path)
        with open(p, encoding="utf8") as f:
            d = json.load(f)

        data = pd.json_normalize(d["cells"])
        # reindex (rather than column selection) tolerates notebooks without
        # code cells, or without any cells, where some of these columns are
        # absent; missing values become NaN.
        filtered_data = data.reindex(columns=["cell_type", "source", "outputs"])
        if self.remove_newline:
            # DataFrame.applymap was deprecated in favour of DataFrame.map
            # (pandas 2.1).
            if hasattr(pd.DataFrame, "map"):
                filtered_data = filtered_data.map(remove_newlines)
            else:
                filtered_data = filtered_data.applymap(remove_newlines)

        # Joining row by row also handles a notebook with zero cells.
        text = " ".join(
            concatenate_cells(
                row, self.include_outputs, self.max_output_length, self.traceback
            )
            for _, row in filtered_data.iterrows()
        )
        metadata = {"source": str(p)}
        return [Document(page_content=text, metadata=metadata)]