from typing import Callable, Dict, Iterator, Optional

from langchain_core.documents import Document

from langchain_community.document_loaders.base import BaseLoader
from langchain_community.utilities.tensorflow_datasets import TensorflowDatasets


class TensorflowDatasetLoader(BaseLoader):
    """Load from a `TensorFlow Dataset`.

    Attributes:
        dataset_name: the name of the dataset to load.
        split_name: the name of the split to load.
        load_max_docs: a limit on the number of loaded documents. Defaults to 100.
        sample_to_document_function: a function that converts a dataset sample
            into a Document.

    Example:
        .. code-block:: python

            import tensorflow as tf

            from langchain_community.document_loaders import TensorflowDatasetLoader
            from langchain_core.documents import Document

            def decode_to_str(item: tf.Tensor) -> str:
                # TFDS yields text fields as byte tensors; decode to Python str.
                return item.numpy().decode("utf-8")

            def mlqaen_example_to_document(example: dict) -> Document:
                return Document(
                    page_content=decode_to_str(example["context"]),
                    metadata={
                        "id": decode_to_str(example["id"]),
                        "title": decode_to_str(example["title"]),
                        "question": decode_to_str(example["question"]),
                        "answer": decode_to_str(example["answers"]["text"][0]),
                    },
                )

            tsds_client = TensorflowDatasetLoader(
                dataset_name="mlqa/en",
                split_name="test",
                load_max_docs=100,
                sample_to_document_function=mlqaen_example_to_document,
            )
    """

    def __init__(
        self,
        dataset_name: str,
        split_name: str,
        load_max_docs: Optional[int] = 100,
        sample_to_document_function: Optional[Callable[[Dict], Document]] = None,
    ):
        """Initialize the TensorflowDatasetLoader.

        Args:
            dataset_name: the name of the dataset to load.
            split_name: the name of the split to load.
            load_max_docs: a limit on the number of loaded documents.
                Defaults to 100.
            sample_to_document_function: a function that converts a dataset sample
                into a Document.
        """
        self.dataset_name: str = dataset_name
        self.split_name: str = split_name
        self.load_max_docs = load_max_docs
        """The maximum number of documents to load."""
        self.sample_to_document_function: Optional[Callable[[Dict], Document]] = (
            sample_to_document_function
        )
        """Custom function that transforms a dataset sample into a Document."""

        # Delegate the actual TFDS access to the shared utility wrapper.
        self._tfds_client = TensorflowDatasets(  # type: ignore[call-arg]
            dataset_name=self.dataset_name,
            split_name=self.split_name,
            load_max_docs=self.load_max_docs,  # type: ignore[arg-type]
            sample_to_document_function=self.sample_to_document_function,
        )

    def lazy_load(self) -> Iterator[Document]:
        """Lazily yield Documents from the underlying TensorFlow dataset."""
        yield from self._tfds_client.lazy_load()
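

# --- Usage sketch (not part of the library API) ------------------------------
# A minimal sketch of driving the loader end to end. It assumes `tensorflow`
# and `tensorflow-datasets` are installed and that the `mlqa/en` dataset can be
# fetched; the converter mirrors the docstring example, and the field selection
# and `load_max_docs=10` are illustrative choices, not requirements.
if __name__ == "__main__":
    import tensorflow as tf

    def _decode_to_str(item: tf.Tensor) -> str:
        # TFDS yields text fields as byte tensors; decode to Python str.
        return item.numpy().decode("utf-8")

    def _example_to_document(example: dict) -> Document:
        return Document(
            page_content=_decode_to_str(example["context"]),
            metadata={"id": _decode_to_str(example["id"])},
        )

    loader = TensorflowDatasetLoader(
        dataset_name="mlqa/en",
        split_name="test",
        load_max_docs=10,
        sample_to_document_function=_example_to_document,
    )

    # lazy_load yields one Document at a time, so memory stays bounded even on
    # large splits; use loader.load() instead to materialize a list at once.
    for doc in loader.lazy_load():
        print(doc.metadata["id"], doc.page_content[:80])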