diff --git a/docs/extras/integrations/document_loaders/tensorflow_datasets.ipynb b/docs/extras/integrations/document_loaders/tensorflow_datasets.ipynb new file mode 100644 index 0000000000..39de886752 --- /dev/null +++ b/docs/extras/integrations/document_loaders/tensorflow_datasets.ipynb @@ -0,0 +1,320 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "bda1f3f5", + "metadata": {}, + "source": [ + "# TensorFlow Datasets\n", + "\n", + ">[TensorFlow Datasets](https://www.tensorflow.org/datasets) is a collection of datasets ready to use, with TensorFlow or other Python ML frameworks, such as Jax. All datasets are exposed as [tf.data.Datasets](https://www.tensorflow.org/api_docs/python/tf/data/Dataset), enabling easy-to-use and high-performance input pipelines. To get started see the [guide](https://www.tensorflow.org/datasets/overview) and the [list of datasets](https://www.tensorflow.org/datasets/catalog/overview#all_datasets).\n", + "\n", + "This notebook shows how to load `TensorFlow Datasets` into a Document format that we can use downstream." + ] + }, + { + "cell_type": "markdown", + "id": "1b7a1eef-7bf7-4e7d-8bfc-c4e27c9488cb", + "metadata": {}, + "source": [ + "## Installation" + ] + }, + { + "cell_type": "markdown", + "id": "2abd5578-aa3d-46b9-99af-8b262f0b3df8", + "metadata": {}, + "source": [ + "You need to install `tensorflow` and `tensorflow-datasets` python packages." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2e589036-351e-4c63-b734-c9a05fadb880", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install tensorflow" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b674aaea-ed3a-4541-8414-260a8f67f623", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "!pip install tensorflow-datasets" + ] + }, + { + "cell_type": "markdown", + "id": "95f05e1c-195e-4e2b-ae8e-8d6637f15be6", + "metadata": {}, + "source": [ + "## Example" + ] + }, + { + "cell_type": "markdown", + "id": "e66e211e-9419-4dbb-b3cd-afc3cf984305", + "metadata": {}, + "source": [ + "As an example, we use the [`mlqa/en` dataset](https://www.tensorflow.org/datasets/catalog/mlqa#mlqaen).\n", + "\n", + ">`MLQA` (`Multilingual Question Answering Dataset`) is a benchmark dataset for evaluating multilingual question answering performance. 
The dataset consists of 7 languages: Arabic, German, Spanish, English, Hindi, Vietnamese, Chinese.\n", + ">\n", + ">- Homepage: https://github.com/facebookresearch/MLQA\n", + ">- Source code: `tfds.datasets.mlqa.Builder`\n", + ">- Download size: 72.21 MiB\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8968d645-c81c-4e3b-82bc-a3cbb5ddd93a", + "metadata": {}, + "outputs": [], + "source": [ + "# Feature structure of `mlqa/en` dataset:\n", + "\n", + "FeaturesDict({\n", + " 'answers': Sequence({\n", + " 'answer_start': int32,\n", + " 'text': Text(shape=(), dtype=string),\n", + " }),\n", + " 'context': Text(shape=(), dtype=string),\n", + " 'id': string,\n", + " 'question': Text(shape=(), dtype=string),\n", + " 'title': Text(shape=(), dtype=string),\n", + "})" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "30fcaba5-cc9b-4a0e-a8f4-c047018451c2", + "metadata": {}, + "outputs": [], + "source": [ + "import tensorflow as tf\n", + "import tensorflow_datasets as tfds" + ] + }, + { + "cell_type": "code", + "execution_count": 78, + "id": "e307dd67-029e-4ee3-a65f-e085c09b0b8b", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "<_TakeDataset element_spec={'answers': {'answer_start': TensorSpec(shape=(None,), dtype=tf.int32, name=None), 'text': TensorSpec(shape=(None,), dtype=tf.string, name=None)}, 'context': TensorSpec(shape=(), dtype=tf.string, name=None), 'id': TensorSpec(shape=(), dtype=tf.string, name=None), 'question': TensorSpec(shape=(), dtype=tf.string, name=None), 'title': TensorSpec(shape=(), dtype=tf.string, name=None)}>" + ] + }, + "execution_count": 78, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Try to access this dataset directly:\n", + "ds = tfds.load('mlqa/en', split='test')\n", + "ds = ds.take(1) # Only take a single example\n", + "ds" + ] + }, + { + "cell_type": "markdown", + "id": "5c9c4b08-d94f-4b53-add0-93769811644e", + "metadata": {}, + "source": [ + "Now we have to create a custom function to convert a dataset sample into a Document.\n", + "\n", + "This is required because there is no standard format for TensorFlow datasets, so we need a custom transformation function.\n", + "\n", + "Let's use the `context` field as the `Document.page_content` and place the other fields in the `Document.metadata`.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 72, + "id": "78844113-f8d8-48a8-8105-685280b6cfa5", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "page_content='After completing the journey around South America, on 23 February 2006, Queen Mary 2 met her namesake, the original RMS Queen Mary, which is permanently docked at Long Beach, California. Escorted by a flotilla of smaller ships, the two Queens exchanged a \"whistle salute\" which was heard throughout the city of Long Beach. Queen Mary 2 met the other serving Cunard liners Queen Victoria and Queen Elizabeth 2 on 13 January 2008 near the Statue of Liberty in New York City harbour, with a celebratory fireworks display; Queen Elizabeth 2 and Queen Victoria made a tandem crossing of the Atlantic for the meeting. This marked the first time three Cunard Queens have been present in the same location. Cunard stated this would be the last time these three ships would ever meet, due to Queen Elizabeth 2\\'s impending retirement from service in late 2008. However this would prove not to be the case, as the three Queens met in Southampton on 22 April 2008. 
Queen Mary 2 rendezvoused with Queen Elizabeth 2 in Dubai on Saturday 21 March 2009, after the latter ship\\'s retirement, while both ships were berthed at Port Rashid. With the withdrawal of Queen Elizabeth 2 from Cunard\\'s fleet and its docking in Dubai, Queen Mary 2 became the only ocean liner left in active passenger service.' metadata={'id': '5116f7cccdbf614d60bcd23498274ffd7b1e4ec7', 'title': 'RMS Queen Mary 2', 'question': 'What year did Queen Mary 2 complete her journey around South America?', 'answer': '2006'}\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2023-08-03 14:27:08.482983: W tensorflow/core/kernels/data/cache_dataset_ops.cc:854] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.\n" + ] + } + ], + "source": [ + "def decode_to_str(item: tf.Tensor) -> str:\n", + " return item.numpy().decode('utf-8')\n", + "\n", + "def mlqaen_example_to_document(example: dict) -> Document:\n", + " return Document(\n", + " page_content=decode_to_str(example[\"context\"]),\n", + " metadata={\n", + " \"id\": decode_to_str(example[\"id\"]),\n", + " \"title\": decode_to_str(example[\"title\"]),\n", + " \"question\": decode_to_str(example[\"question\"]),\n", + " \"answer\": decode_to_str(example[\"answers\"][\"text\"][0]),\n", + " },\n", + " )\n", + " \n", + " \n", + "for example in ds: \n", + " doc = mlqaen_example_to_document(example)\n", + " print(doc)\n", + " break" + ] + }, + { + "cell_type": "code", + "execution_count": 73, + "id": "2d43c834-5145-4793-9558-8e301ccaf3b4", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.schema import Document\n", + "from langchain.document_loaders import TensorflowDatasetLoader\n", + "\n", + "loader = TensorflowDatasetLoader(\n", + " dataset_name=\"mlqa/en\",\n", + " split_name=\"test\",\n", + " load_max_docs=3,\n", + " sample_to_document_function=mlqaen_example_to_document,\n", + " )" + ] + }, + { + "cell_type": "markdown", + "id": "e29b954c-1407-4797-ae21-6ba8937156be", + "metadata": {}, + "source": [ + "`TensorflowDatasetLoader` has these parameters:\n", + "- `dataset_name`: the name of the dataset to load\n", + "- `split_name`: the name of the split to load. Defaults to \"train\".\n", + "- `load_max_docs`: a limit to the number of loaded documents. Defaults to 100.\n", + "- `sample_to_document_function`: a function that converts a dataset sample to a Document\n" + ] + }, + { + "cell_type": "code", + "execution_count": 74, + "id": "700e4ef2", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2023-08-03 14:27:22.998964: W tensorflow/core/kernels/data/cache_dataset_ops.cc:854] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. 
You should use `dataset.take(k).cache().repeat()` instead.\n" + ] + }, + { + "data": { + "text/plain": [ + "3" + ] + }, + "execution_count": 74, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "docs = loader.load()\n", + "len(docs)" + ] + }, + { + "cell_type": "code", + "execution_count": 76, + "id": "9138940a-e9fe-4145-83e8-77589b5272c9", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'After completing the journey around South America, on 23 February 2006, Queen Mary 2 met her namesake, the original RMS Queen Mary, which is permanently docked at Long Beach, California. Escorted by a flotilla of smaller ships, the two Queens exchanged a \"whistle salute\" which was heard throughout the city of Long Beach. Queen Mary 2 met the other serving Cunard liners Queen Victoria and Queen Elizabeth 2 on 13 January 2008 near the Statue of Liberty in New York City harbour, with a celebratory fireworks display; Queen Elizabeth 2 and Queen Victoria made a tandem crossing of the Atlantic for the meeting. This marked the first time three Cunard Queens have been present in the same location. Cunard stated this would be the last time these three ships would ever meet, due to Queen Elizabeth 2\\'s impending retirement from service in late 2008. However this would prove not to be the case, as the three Queens met in Southampton on 22 April 2008. Queen Mary 2 rendezvoused with Queen Elizabeth 2 in Dubai on Saturday 21 March 2009, after the latter ship\\'s retirement, while both ships were berthed at Port Rashid. With the withdrawal of Queen Elizabeth 2 from Cunard\\'s fleet and its docking in Dubai, Queen Mary 2 became the only ocean liner left in active passenger service.'" + ] + }, + "execution_count": 76, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "docs[0].page_content" + ] + }, + { + "cell_type": "code", + "execution_count": 77, + "id": "2f7f7832-fe4d-4a58-892d-bb987cdbed0b", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'id': '5116f7cccdbf614d60bcd23498274ffd7b1e4ec7',\n", + " 'title': 'RMS Queen Mary 2',\n", + " 'question': 'What year did Queen Mary 2 complete her journey around South America?',\n", + " 'answer': '2006'}" + ] + }, + "execution_count": 77, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "docs[0].metadata" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "125d073c-4f4f-4ae6-a0c7-9e9db3cc8d69", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/extras/integrations/providers/tensorflow_datasets.mdx b/docs/extras/integrations/providers/tensorflow_datasets.mdx new file mode 100644 index 0000000000..6b77756344 --- /dev/null +++ b/docs/extras/integrations/providers/tensorflow_datasets.mdx @@ -0,0 +1,31 @@ +# TensorFlow Datasets + +>[TensorFlow Datasets](https://www.tensorflow.org/datasets) is a collection of datasets ready to use, +> with TensorFlow or other Python ML frameworks, such as Jax. 
All datasets are exposed +> as [tf.data.Datasets](https://www.tensorflow.org/api_docs/python/tf/data/Dataset), +> enabling easy-to-use and high-performance input pipelines. To get started see +> the [guide](https://www.tensorflow.org/datasets/overview) and +> the [list of datasets](https://www.tensorflow.org/datasets/catalog/overview#all_datasets). + + + +## Installation and Setup + +You need to install the `tensorflow` and `tensorflow-datasets` Python packages. + +```bash +pip install tensorflow +``` + +```bash +pip install tensorflow-datasets +``` + + +## Document Loader + +See a [usage example](/docs/integrations/document_loaders/tensorflow_datasets). + +```python +from langchain.document_loaders import TensorflowDatasetLoader +``` diff --git a/libs/langchain/langchain/document_loaders/__init__.py b/libs/langchain/langchain/document_loaders/__init__.py index d0a2a0211c..b52c0927db 100644 --- a/libs/langchain/langchain/document_loaders/__init__.py +++ b/libs/langchain/langchain/document_loaders/__init__.py @@ -147,6 +147,7 @@ from langchain.document_loaders.telegram import ( ) from langchain.document_loaders.tencent_cos_directory import TencentCOSDirectoryLoader from langchain.document_loaders.tencent_cos_file import TencentCOSFileLoader +from langchain.document_loaders.tensorflow_datasets import TensorflowDatasetLoader from langchain.document_loaders.text import TextLoader from langchain.document_loaders.tomarkdown import ToMarkdownLoader from langchain.document_loaders.toml import TomlLoader @@ -299,6 +300,7 @@ __all__ = [ "TelegramChatApiLoader", "TelegramChatFileLoader", "TelegramChatLoader", + "TensorflowDatasetLoader", "TencentCOSDirectoryLoader", "TencentCOSFileLoader", "TextLoader", diff --git a/libs/langchain/langchain/document_loaders/arxiv.py b/libs/langchain/langchain/document_loaders/arxiv.py index d0ae945238..6a7e139ca4 100644 --- a/libs/langchain/langchain/document_loaders/arxiv.py +++ b/libs/langchain/langchain/document_loaders/arxiv.py @@ -8,7 +8,6 @@ from langchain.utilities.arxiv import ArxivAPIWrapper class ArxivLoader(BaseLoader): """Loads a query result from arxiv.org into a list of Documents. - Each document represents one Document. The loader converts the original PDF format into the text. """ diff --git a/libs/langchain/langchain/document_loaders/tensorflow_datasets.py b/libs/langchain/langchain/document_loaders/tensorflow_datasets.py new file mode 100644 index 0000000000..e908aac873 --- /dev/null +++ b/libs/langchain/langchain/document_loaders/tensorflow_datasets.py @@ -0,0 +1,79 @@ +from typing import Callable, Dict, Iterator, List, Optional + +from langchain.document_loaders.base import BaseLoader + +from langchain.schema import Document +from langchain.utilities.tensorflow_datasets import TensorflowDatasets + + +class TensorflowDatasetLoader(BaseLoader): + """Loads from TensorFlow Datasets into a list of Documents. + + Attributes: + dataset_name: the name of the dataset to load + split_name: the name of the split to load. + load_max_docs: a limit to the number of loaded documents. Defaults to 100. + sample_to_document_function: a function that converts a dataset sample + into a Document + + Example: + .. 
code-block:: python + + from langchain.document_loaders import TensorflowDatasetLoader + + def mlqaen_example_to_document(example: dict) -> Document: + return Document( + page_content=decode_to_str(example["context"]), + metadata={ + "id": decode_to_str(example["id"]), + "title": decode_to_str(example["title"]), + "question": decode_to_str(example["question"]), + "answer": decode_to_str(example["answers"]["text"][0]), + }, + ) + + tsds_client = TensorflowDatasetLoader( + dataset_name="mlqa/en", + split_name="test", + load_max_docs=100, + sample_to_document_function=mlqaen_example_to_document, + ) + + """ + + def __init__( + self, + dataset_name: str, + split_name: str, + load_max_docs: Optional[int] = 100, + sample_to_document_function: Optional[Callable[[Dict], Document]] = None, + ): + """Initialize the TensorflowDatasetLoader. + + Args: + dataset_name: the name of the dataset to load + split_name: the name of the split to load. + load_max_docs: a limit to the number of loaded documents. Defaults to 100. + sample_to_document_function: a function that converts a dataset sample + into a Document. + """ + self.dataset_name: str = dataset_name + self.split_name: str = split_name + self.load_max_docs = load_max_docs + """The maximum number of documents to load.""" + self.sample_to_document_function: Optional[ + Callable[[Dict], Document] + ] = sample_to_document_function + """Custom function that transform a dataset sample into a Document.""" + + self._tfds_client = TensorflowDatasets( + dataset_name=self.dataset_name, + split_name=self.split_name, + load_max_docs=self.load_max_docs, + sample_to_document_function=self.sample_to_document_function, + ) + + def lazy_load(self) -> Iterator[Document]: + yield from self._tfds_client.lazy_load() + + def load(self) -> List[Document]: + return list(self.lazy_load()) diff --git a/libs/langchain/langchain/utilities/__init__.py b/libs/langchain/langchain/utilities/__init__.py index b747ac9df5..167ff1f8fd 100644 --- a/libs/langchain/langchain/utilities/__init__.py +++ b/libs/langchain/langchain/utilities/__init__.py @@ -29,6 +29,7 @@ from langchain.utilities.searx_search import SearxSearchWrapper from langchain.utilities.serpapi import SerpAPIWrapper from langchain.utilities.spark_sql import SparkSQL from langchain.utilities.sql_database import SQLDatabase +from langchain.utilities.tensorflow_datasets import TensorflowDatasets from langchain.utilities.twilio import TwilioAPIWrapper from langchain.utilities.wikipedia import WikipediaAPIWrapper from langchain.utilities.wolfram_alpha import WolframAlphaAPIWrapper @@ -62,6 +63,7 @@ __all__ = [ "SearxSearchWrapper", "SerpAPIWrapper", "SparkSQL", + "TensorflowDatasets", "TextRequestsWrapper", "TextRequestsWrapper", "TwilioAPIWrapper", diff --git a/libs/langchain/langchain/utilities/arxiv.py b/libs/langchain/langchain/utilities/arxiv.py index 2ad42daba8..246ed89045 100644 --- a/libs/langchain/langchain/utilities/arxiv.py +++ b/libs/langchain/langchain/utilities/arxiv.py @@ -21,7 +21,7 @@ class ArxivAPIWrapper(BaseModel): It limits the Document content by doc_content_chars_max. Set doc_content_chars_max=None if you don't want to limit the content size. - Args: + Attributes: top_k_results: number of the top-scored document used for the arxiv tool ARXIV_MAX_QUERY_LENGTH: the cut limit on the query used for the arxiv tool. 
load_max_docs: a limit to the number of loaded documents diff --git a/libs/langchain/langchain/utilities/tensorflow_datasets.py b/libs/langchain/langchain/utilities/tensorflow_datasets.py new file mode 100644 index 0000000000..1a7073d2f7 --- /dev/null +++ b/libs/langchain/langchain/utilities/tensorflow_datasets.py @@ -0,0 +1,111 @@ +import logging +from typing import Any, Callable, Dict, Iterator, List, Optional + +from pydantic import BaseModel, root_validator + +from langchain.schema import Document + +logger = logging.getLogger(__name__) + + +class TensorflowDatasets(BaseModel): + """Access to the TensorFlow Datasets. + + The current implementation works only with datasets that fit in memory. + + `TensorFlow Datasets` is a collection of datasets ready to use, with TensorFlow + or other Python ML frameworks, such as Jax. All datasets are exposed + as `tf.data.Datasets`. + To get started see the Guide: https://www.tensorflow.org/datasets/overview and + the list of datasets: https://www.tensorflow.org/datasets/catalog/ + overview#all_datasets + + You have to provide the sample_to_document_function: a function that + converts a sample from the dataset-specific format to a Document. + + Attributes: + dataset_name: the name of the dataset to load + split_name: the name of the split to load. Defaults to "train". + load_max_docs: a limit to the number of loaded documents. Defaults to 100. + sample_to_document_function: a function that converts a dataset sample + to a Document + + Example: + .. code-block:: python + + from langchain.utilities import TensorflowDatasets + + def mlqaen_example_to_document(example: dict) -> Document: + return Document( + page_content=decode_to_str(example["context"]), + metadata={ + "id": decode_to_str(example["id"]), + "title": decode_to_str(example["title"]), + "question": decode_to_str(example["question"]), + "answer": decode_to_str(example["answers"]["text"][0]), + }, + ) + + tsds_client = TensorflowDatasets( + dataset_name="mlqa/en", + split_name="train", + load_max_docs=MAX_DOCS, + sample_to_document_function=mlqaen_example_to_document, + ) + + """ + + dataset_name: str = "" + split_name: str = "train" + load_max_docs: int = 100 + sample_to_document_function: Optional[Callable[[Dict], Document]] = None + dataset: Any #: :meta private: + + @root_validator() + def validate_environment(cls, values: Dict) -> Dict: + """Validate that the required python packages exist in the environment.""" + try: + import tensorflow # noqa: F401 + except ImportError: + raise ImportError( + "Could not import tensorflow python package. " + "Please install it with `pip install tensorflow`." + ) + try: + import tensorflow_datasets + except ImportError: + raise ImportError( + "Could not import tensorflow_datasets python package. " + "Please install it with `pip install tensorflow-datasets`." + ) + if values["sample_to_document_function"] is None: + raise ValueError( + "sample_to_document_function is None. " + "Please provide a function that converts a dataset sample to" + " a Document." + ) + values["dataset"] = tensorflow_datasets.load( + values["dataset_name"], split=values["split_name"] + ) + + return values + + def lazy_load(self) -> Iterator[Document]: + """Download a selected dataset lazily. + + Returns: an iterator of Documents. + + """ + return ( + self.sample_to_document_function(s) + for s in self.dataset.take(self.load_max_docs) + if self.sample_to_document_function is not None + ) + + def load(self) -> List[Document]: + """Download a selected dataset. + + Returns: a list of Documents. 
+ + """ + return list(self.lazy_load()) diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_tensorflow_datasets.py b/libs/langchain/tests/integration_tests/document_loaders/test_tensorflow_datasets.py new file mode 100644 index 0000000000..50a5f7e96f --- /dev/null +++ b/libs/langchain/tests/integration_tests/document_loaders/test_tensorflow_datasets.py @@ -0,0 +1,105 @@ +"""Integration tests for the TensorFlow Dataset Loader.""" + +import pytest +from pydantic.error_wrappers import ValidationError + +from langchain.document_loaders.tensorflow_datasets import TensorflowDatasetLoader +from langchain.schema.document import Document + +# adding tensorflow and tensorflow_datasets to pyproject.toml is not working +# these tests can be run in isolation only +tensorflow = pytest.importorskip("tensorflow") +tensorflow_datasets = pytest.importorskip("tensorflow_datasets") + +# placed here after checking for tensorflow package installation +import tensorflow as tf # noqa: E402 + + +def decode_to_str(item: tf.Tensor) -> str: + return item.numpy().decode("utf-8") + + +def mlqaen_example_to_document(example: dict) -> Document: + return Document( + page_content=decode_to_str(example["context"]), + metadata={ + "id": decode_to_str(example["id"]), + "title": decode_to_str(example["title"]), + "question": decode_to_str(example["question"]), + "answer": decode_to_str(example["answers"]["text"][0]), + }, + ) + + +MAX_DOCS = 10 + + +@pytest.fixture +def tfds_client() -> TensorflowDatasetLoader: + return TensorflowDatasetLoader( + dataset_name="mlqa/en", + split_name="test", + load_max_docs=MAX_DOCS, + sample_to_document_function=mlqaen_example_to_document, + ) + + +def test_load_success(tfds_client: TensorflowDatasetLoader) -> None: + """Test that returns the correct answer""" + + output = tfds_client.load() + assert isinstance(output, list) + assert len(output) == MAX_DOCS + + assert isinstance(output[0], Document) + assert len(output[0].page_content) > 0 + assert isinstance(output[0].page_content, str) + assert isinstance(output[0].metadata, dict) + + +def test_lazy_load_success(tfds_client: TensorflowDatasetLoader) -> None: + """Test that returns the correct answer""" + + output = list(tfds_client.lazy_load()) + assert isinstance(output, list) + assert len(output) == MAX_DOCS + + assert isinstance(output[0], Document) + assert len(output[0].page_content) > 0 + assert isinstance(output[0].page_content, str) + assert isinstance(output[0].metadata, dict) + + +def test_load_fail_wrong_dataset_name() -> None: + """Test that fails to load""" + with pytest.raises(ValidationError) as exc_info: + TensorflowDatasetLoader( + dataset_name="wrong_dataset_name", + split_name="test", + load_max_docs=MAX_DOCS, + sample_to_document_function=mlqaen_example_to_document, + ) + assert "the dataset name is spelled correctly" in str(exc_info.value) + + +def test_load_fail_wrong_split_name() -> None: + """Test that fails to load""" + with pytest.raises(ValidationError) as exc_info: + TensorflowDatasetLoader( + dataset_name="mlqa/en", + split_name="wrong_split_name", + load_max_docs=MAX_DOCS, + sample_to_document_function=mlqaen_example_to_document, + ) + assert "Unknown split" in str(exc_info.value) + + +def test_load_fail_no_func() -> None: + """Test that fails to load""" + with pytest.raises(ValidationError) as exc_info: + TensorflowDatasetLoader( + dataset_name="mlqa/en", + split_name="test", + load_max_docs=MAX_DOCS, + ) + assert "Please provide a function" in str(exc_info.value) diff --git 
a/libs/langchain/tests/integration_tests/utilities/test_tensorflow_datasets.py b/libs/langchain/tests/integration_tests/utilities/test_tensorflow_datasets.py new file mode 100644 index 0000000000..de2d572b18 --- /dev/null +++ b/libs/langchain/tests/integration_tests/utilities/test_tensorflow_datasets.py @@ -0,0 +1,90 @@ +"""Integration tests for the TensorFlow Dataset client.""" + +import pytest +from pydantic.error_wrappers import ValidationError + +from langchain.schema.document import Document +from langchain.utilities.tensorflow_datasets import TensorflowDatasets + +# adding tensorflow and tensorflow_datasets to pyproject.toml is not working +# these tests can be run in isolation only +tensorflow = pytest.importorskip("tensorflow") +tensorflow_datasets = pytest.importorskip("tensorflow_datasets") + +# placed here after checking for tensorflow package installation +import tensorflow as tf # noqa: E402 + + +def decode_to_str(item: tf.Tensor) -> str: + return item.numpy().decode("utf-8") + + +def mlqaen_example_to_document(example: dict) -> Document: + return Document( + page_content=decode_to_str(example["context"]), + metadata={ + "id": decode_to_str(example["id"]), + "title": decode_to_str(example["title"]), + "question": decode_to_str(example["question"]), + "answer": decode_to_str(example["answers"]["text"][0]), + }, + ) + + +MAX_DOCS = 10 + + +@pytest.fixture +def tfds_client() -> TensorflowDatasets: + return TensorflowDatasets( + dataset_name="mlqa/en", + split_name="test", + load_max_docs=MAX_DOCS, + sample_to_document_function=mlqaen_example_to_document, + ) + + +def test_load_success(tfds_client: TensorflowDatasets) -> None: + """Test that returns the correct answer""" + + output = tfds_client.load() + assert isinstance(output, list) + assert len(output) == MAX_DOCS + + assert isinstance(output[0], Document) + assert len(output[0].page_content) > 0 + assert isinstance(output[0].page_content, str) + assert isinstance(output[0].metadata, dict) + + +def test_load_fail_wrong_dataset_name() -> None: + """Test that fails to load""" + with pytest.raises(ValidationError) as exc_info: + TensorflowDatasets( + dataset_name="wrong_dataset_name", + split_name="test", + load_max_docs=MAX_DOCS, + sample_to_document_function=mlqaen_example_to_document, + ) + assert "the dataset name is spelled correctly" in str(exc_info.value) + + +def test_load_fail_wrong_split_name() -> None: + """Test that fails to load""" + with pytest.raises(ValidationError) as exc_info: + TensorflowDatasets( + dataset_name="mlqa/en", + split_name="wrong_split_name", + load_max_docs=MAX_DOCS, + sample_to_document_function=mlqaen_example_to_document, + ) + assert "Unknown split" in str(exc_info.value) + + +def test_load_fail_no_func() -> None: + """Test that fails to load""" + with pytest.raises(ValidationError) as exc_info: + TensorflowDatasets( + dataset_name="mlqa/en", + split_name="test", + load_max_docs=MAX_DOCS, + ) + assert "Please provide a function" in str(exc_info.value)
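+ + +# A possible addition, not part of the original change: the utilities client also +# exposes lazy_load(), which yields Documents one at a time, so it could be +# exercised the same way as the loader's test_lazy_load_success above, reusing +# the existing tfds_client fixture and MAX_DOCS constant. A sketch: +def test_lazy_load_success(tfds_client: TensorflowDatasets) -> None: + """Test that lazy_load returns an iterator of correct Documents""" + + output = list(tfds_client.lazy_load()) + assert isinstance(output, list) + assert len(output) == MAX_DOCS + + assert isinstance(output[0], Document) + assert len(output[0].page_content) > 0 + assert isinstance(output[0].page_content, str) + assert isinstance(output[0].metadata, dict)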