`tensorflow_datasets` document loader (#8721)

This PR adds the `tensorflow_datasets` document loader.
Leonid Ganeline 1 year ago committed by GitHub
parent fad26e79a3
commit 33a2f58fbf

@ -0,0 +1,320 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "bda1f3f5",
"metadata": {},
"source": [
"# TensorFlow Datasets\n",
"\n",
">[TensorFlow Datasets](https://www.tensorflow.org/datasets) is a collection of datasets ready to use, with TensorFlow or other Python ML frameworks, such as Jax. All datasets are exposed as [tf.data.Datasets](https://www.tensorflow.org/api_docs/python/tf/data/Dataset), enabling easy-to-use and high-performance input pipelines. To get started see the [guide](https://www.tensorflow.org/datasets/overview) and the [list of datasets](https://www.tensorflow.org/datasets/catalog/overview#all_datasets).\n",
"\n",
"This notebook shows how to load `TensorFlow Datasets` into a Document format that we can use downstream."
]
},
{
"cell_type": "markdown",
"id": "1b7a1eef-7bf7-4e7d-8bfc-c4e27c9488cb",
"metadata": {},
"source": [
"## Installation"
]
},
{
"cell_type": "markdown",
"id": "2abd5578-aa3d-46b9-99af-8b262f0b3df8",
"metadata": {},
"source": [
"You need to install `tensorflow` and `tensorflow-datasets` python packages."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2e589036-351e-4c63-b734-c9a05fadb880",
"metadata": {},
"outputs": [],
"source": [
"!pip install tensorflow"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b674aaea-ed3a-4541-8414-260a8f67f623",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"!pip install tensorflow-datasets"
]
},
{
"cell_type": "markdown",
"id": "95f05e1c-195e-4e2b-ae8e-8d6637f15be6",
"metadata": {},
"source": [
"## Example"
]
},
{
"cell_type": "markdown",
"id": "e66e211e-9419-4dbb-b3cd-afc3cf984305",
"metadata": {},
"source": [
"As an example, we use the [`mlqa/en` dataset](https://www.tensorflow.org/datasets/catalog/mlqa#mlqaen).\n",
"\n",
">`MLQA` (`Multilingual Question Answering Dataset`) is a benchmark dataset for evaluating multilingual question answering performance. The dataset consists of 7 languages: Arabic, German, Spanish, English, Hindi, Vietnamese, Chinese.\n",
">\n",
">- Homepage: https://github.com/facebookresearch/MLQA\n",
">- Source code: `tfds.datasets.mlqa.Builder`\n",
">- Download size: 72.21 MiB\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8968d645-c81c-4e3b-82bc-a3cbb5ddd93a",
"metadata": {},
"outputs": [],
"source": [
"# Feature structure of `mlqa/en` dataset:\n",
"\n",
"FeaturesDict({\n",
" 'answers': Sequence({\n",
" 'answer_start': int32,\n",
" 'text': Text(shape=(), dtype=string),\n",
" }),\n",
" 'context': Text(shape=(), dtype=string),\n",
" 'id': string,\n",
" 'question': Text(shape=(), dtype=string),\n",
" 'title': Text(shape=(), dtype=string),\n",
"})"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "30fcaba5-cc9b-4a0e-a8f4-c047018451c2",
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"import tensorflow_datasets as tfds"
]
},
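{
"cell_type": "markdown",
"id": "tfds-builder-info-md",
"metadata": {},
"source": [
"The feature structure shown above can also be inspected programmatically through the dataset builder; a minimal sketch (constructing the builder should not require downloading the full dataset):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "tfds-builder-info-code",
"metadata": {},
"outputs": [],
"source": [
"# Look up the `mlqa/en` builder and show its feature spec.\n",
"builder = tfds.builder(\"mlqa/en\")\n",
"builder.info.features"
]
},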
{
"cell_type": "code",
"execution_count": 78,
"id": "e307dd67-029e-4ee3-a65f-e085c09b0b8b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<_TakeDataset element_spec={'answers': {'answer_start': TensorSpec(shape=(None,), dtype=tf.int32, name=None), 'text': TensorSpec(shape=(None,), dtype=tf.string, name=None)}, 'context': TensorSpec(shape=(), dtype=tf.string, name=None), 'id': TensorSpec(shape=(), dtype=tf.string, name=None), 'question': TensorSpec(shape=(), dtype=tf.string, name=None), 'title': TensorSpec(shape=(), dtype=tf.string, name=None)}>"
]
},
"execution_count": 78,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# try directly access this dataset:\n",
"ds = tfds.load('mlqa/en', split='test')\n",
"ds = ds.take(1) # Only take a single example\n",
"ds"
]
},
{
"cell_type": "markdown",
"id": "5c9c4b08-d94f-4b53-add0-93769811644e",
"metadata": {},
"source": [
"Now we have to create a custom function to convert dataset sample into a Document.\n",
"\n",
"This is a requirement. There is no standard format for the TF datasets that's why we need to make a custom transformation function.\n",
"\n",
"Let's use `context` field as the `Document.page_content` and place other fields in the `Document.metadata`.\n"
]
},
{
"cell_type": "code",
"execution_count": 72,
"id": "78844113-f8d8-48a8-8105-685280b6cfa5",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"page_content='After completing the journey around South America, on 23 February 2006, Queen Mary 2 met her namesake, the original RMS Queen Mary, which is permanently docked at Long Beach, California. Escorted by a flotilla of smaller ships, the two Queens exchanged a \"whistle salute\" which was heard throughout the city of Long Beach. Queen Mary 2 met the other serving Cunard liners Queen Victoria and Queen Elizabeth 2 on 13 January 2008 near the Statue of Liberty in New York City harbour, with a celebratory fireworks display; Queen Elizabeth 2 and Queen Victoria made a tandem crossing of the Atlantic for the meeting. This marked the first time three Cunard Queens have been present in the same location. Cunard stated this would be the last time these three ships would ever meet, due to Queen Elizabeth 2\\'s impending retirement from service in late 2008. However this would prove not to be the case, as the three Queens met in Southampton on 22 April 2008. Queen Mary 2 rendezvoused with Queen Elizabeth 2 in Dubai on Saturday 21 March 2009, after the latter ship\\'s retirement, while both ships were berthed at Port Rashid. With the withdrawal of Queen Elizabeth 2 from Cunard\\'s fleet and its docking in Dubai, Queen Mary 2 became the only ocean liner left in active passenger service.' metadata={'id': '5116f7cccdbf614d60bcd23498274ffd7b1e4ec7', 'title': 'RMS Queen Mary 2', 'question': 'What year did Queen Mary 2 complete her journey around South America?', 'answer': '2006'}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2023-08-03 14:27:08.482983: W tensorflow/core/kernels/data/cache_dataset_ops.cc:854] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.\n"
]
}
],
"source": [
"def decode_to_str(item: tf.Tensor) -> str:\n",
" return item.numpy().decode('utf-8')\n",
"\n",
"def mlqaen_example_to_document(example: dict) -> Document:\n",
" return Document(\n",
" page_content=decode_to_str(example[\"context\"]),\n",
" metadata={\n",
" \"id\": decode_to_str(example[\"id\"]),\n",
" \"title\": decode_to_str(example[\"title\"]),\n",
" \"question\": decode_to_str(example[\"question\"]),\n",
" \"answer\": decode_to_str(example[\"answers\"][\"text\"][0]),\n",
" },\n",
" )\n",
" \n",
" \n",
"for example in ds: \n",
" doc = mlqaen_example_to_document(example)\n",
" print(doc)\n",
" break"
]
},
{
"cell_type": "code",
"execution_count": 73,
"id": "2d43c834-5145-4793-9558-8e301ccaf3b4",
"metadata": {},
"outputs": [],
"source": [
"from langchain.schema import Document\n",
"from langchain.document_loaders import TensorflowDatasetLoader\n",
"\n",
"loader = TensorflowDatasetLoader(\n",
" dataset_name=\"mlqa/en\",\n",
" split_name=\"test\",\n",
" load_max_docs=3,\n",
" sample_to_document_function=mlqaen_example_to_document,\n",
" )"
]
},
{
"cell_type": "markdown",
"id": "e29b954c-1407-4797-ae21-6ba8937156be",
"metadata": {},
"source": [
"`TensorflowDatasetLoader` has these parameters:\n",
"- `dataset_name`: the name of the dataset to load\n",
"- `split_name`: the name of the split to load. Defaults to \"train\".\n",
"- `load_max_docs`: a limit to the number of loaded documents. Defaults to 100.\n",
"- `sample_to_document_function`: a function that converts a dataset sample to a Document\n"
]
},
{
"cell_type": "code",
"execution_count": 74,
"id": "700e4ef2",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2023-08-03 14:27:22.998964: W tensorflow/core/kernels/data/cache_dataset_ops.cc:854] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.\n"
]
},
{
"data": {
"text/plain": [
"3"
]
},
"execution_count": 74,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"docs = loader.load()\n",
"len(docs)"
]
},
{
"cell_type": "code",
"execution_count": 76,
"id": "9138940a-e9fe-4145-83e8-77589b5272c9",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'After completing the journey around South America, on 23 February 2006, Queen Mary 2 met her namesake, the original RMS Queen Mary, which is permanently docked at Long Beach, California. Escorted by a flotilla of smaller ships, the two Queens exchanged a \"whistle salute\" which was heard throughout the city of Long Beach. Queen Mary 2 met the other serving Cunard liners Queen Victoria and Queen Elizabeth 2 on 13 January 2008 near the Statue of Liberty in New York City harbour, with a celebratory fireworks display; Queen Elizabeth 2 and Queen Victoria made a tandem crossing of the Atlantic for the meeting. This marked the first time three Cunard Queens have been present in the same location. Cunard stated this would be the last time these three ships would ever meet, due to Queen Elizabeth 2\\'s impending retirement from service in late 2008. However this would prove not to be the case, as the three Queens met in Southampton on 22 April 2008. Queen Mary 2 rendezvoused with Queen Elizabeth 2 in Dubai on Saturday 21 March 2009, after the latter ship\\'s retirement, while both ships were berthed at Port Rashid. With the withdrawal of Queen Elizabeth 2 from Cunard\\'s fleet and its docking in Dubai, Queen Mary 2 became the only ocean liner left in active passenger service.'"
]
},
"execution_count": 76,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"docs[0].page_content"
]
},
{
"cell_type": "code",
"execution_count": 77,
"id": "2f7f7832-fe4d-4a58-892d-bb987cdbed0b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'id': '5116f7cccdbf614d60bcd23498274ffd7b1e4ec7',\n",
" 'title': 'RMS Queen Mary 2',\n",
" 'question': 'What year did Queen Mary 2 complete her journey around South America?',\n",
" 'answer': '2006'}"
]
},
"execution_count": 77,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"docs[0].metadata"
]
},
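{
"cell_type": "markdown",
"id": "tfds-lazy-load-md",
"metadata": {},
"source": [
"The loader also exposes `lazy_load()`, which yields `Document`s one at a time instead of materializing the whole list; a minimal sketch reusing the `loader` defined above:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "tfds-lazy-load-code",
"metadata": {},
"outputs": [],
"source": [
"# Iterate lazily over at most `load_max_docs` converted documents.\n",
"for doc in loader.lazy_load():\n",
"    print(doc.metadata[\"title\"])"
]
},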
{
"cell_type": "code",
"execution_count": null,
"id": "125d073c-4f4f-4ae6-a0c7-9e9db3cc8d69",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

@ -0,0 +1,31 @@
# TensorFlow Datasets

>[TensorFlow Datasets](https://www.tensorflow.org/datasets) is a collection of datasets ready to use,
> with TensorFlow or other Python ML frameworks, such as Jax. All datasets are exposed
> as [tf.data.Datasets](https://www.tensorflow.org/api_docs/python/tf/data/Dataset),
> enabling easy-to-use and high-performance input pipelines. To get started see
> the [guide](https://www.tensorflow.org/datasets/overview) and
> the [list of datasets](https://www.tensorflow.org/datasets/catalog/overview#all_datasets).

## Installation and Setup

You need to install the `tensorflow` and `tensorflow-datasets` Python packages.

```bash
pip install tensorflow
```

```bash
pip install tensorflow-datasets
```

## Document Loader

See a [usage example](/docs/integrations/document_loaders/tensorflow_datasets).

```python
from langchain.document_loaders import TensorflowDatasetLoader
```
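
For completeness, a minimal end-to-end sketch mirroring the notebook linked above; the `mlqa/en` dataset, the `test` split, and the `to_document` converter below are illustrative and would be adapted to the feature structure of whatever dataset you load:

```python
from langchain.document_loaders import TensorflowDatasetLoader
from langchain.schema import Document


def to_document(example: dict) -> Document:
    # Samples arrive as dicts of tf.Tensor values; decode the tf.string fields.
    def decode(t) -> str:
        return t.numpy().decode("utf-8")

    return Document(
        page_content=decode(example["context"]),
        metadata={"title": decode(example["title"]), "question": decode(example["question"])},
    )


loader = TensorflowDatasetLoader(
    dataset_name="mlqa/en",
    split_name="test",
    load_max_docs=3,
    sample_to_document_function=to_document,
)
docs = loader.load()  # at most `load_max_docs` Documents
```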

@ -147,6 +147,7 @@ from langchain.document_loaders.telegram import (
)
from langchain.document_loaders.tencent_cos_directory import TencentCOSDirectoryLoader
from langchain.document_loaders.tencent_cos_file import TencentCOSFileLoader
from langchain.document_loaders.tensorflow_datasets import TensorflowDatasetLoader
from langchain.document_loaders.text import TextLoader
from langchain.document_loaders.tomarkdown import ToMarkdownLoader
from langchain.document_loaders.toml import TomlLoader
@ -299,6 +300,7 @@ __all__ = [
"TelegramChatApiLoader", "TelegramChatApiLoader",
"TelegramChatFileLoader", "TelegramChatFileLoader",
"TelegramChatLoader", "TelegramChatLoader",
"TensorflowDatasetLoader",
"TencentCOSDirectoryLoader", "TencentCOSDirectoryLoader",
"TencentCOSFileLoader", "TencentCOSFileLoader",
"TextLoader", "TextLoader",

@ -8,7 +8,6 @@ from langchain.utilities.arxiv import ArxivAPIWrapper
class ArxivLoader(BaseLoader):
    """Loads a query result from arxiv.org into a list of Documents.
    Each document represents one Document.
    The loader converts the original PDF format into the text.
    """

@ -0,0 +1,79 @@
from typing import Callable, Dict, Iterator, List, Optional

from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document
from langchain.utilities.tensorflow_datasets import TensorflowDatasets


class TensorflowDatasetLoader(BaseLoader):
    """Loads from TensorFlow Datasets into a list of Documents.

    Attributes:
        dataset_name: the name of the dataset to load
        split_name: the name of the split to load.
        load_max_docs: a limit to the number of loaded documents. Defaults to 100.
        sample_to_document_function: a function that converts a dataset sample
            into a Document

    Example:
        .. code-block:: python

            from langchain.document_loaders import TensorflowDatasetLoader

            def mlqaen_example_to_document(example: dict) -> Document:
                return Document(
                    page_content=decode_to_str(example["context"]),
                    metadata={
                        "id": decode_to_str(example["id"]),
                        "title": decode_to_str(example["title"]),
                        "question": decode_to_str(example["question"]),
                        "answer": decode_to_str(example["answers"]["text"][0]),
                    },
                )

            tsds_client = TensorflowDatasetLoader(
                dataset_name="mlqa/en",
                split_name="test",
                load_max_docs=100,
                sample_to_document_function=mlqaen_example_to_document,
            )
    """

    def __init__(
        self,
        dataset_name: str,
        split_name: str,
        load_max_docs: Optional[int] = 100,
        sample_to_document_function: Optional[Callable[[Dict], Document]] = None,
    ):
        """Initialize the TensorflowDatasetLoader.

        Args:
            dataset_name: the name of the dataset to load
            split_name: the name of the split to load.
            load_max_docs: a limit to the number of loaded documents.
                Defaults to 100.
            sample_to_document_function: a function that converts a dataset sample
                into a Document.
        """
        self.dataset_name: str = dataset_name
        self.split_name: str = split_name
        self.load_max_docs = load_max_docs
        """The maximum number of documents to load."""
        self.sample_to_document_function: Optional[
            Callable[[Dict], Document]
        ] = sample_to_document_function
        """Custom function that transforms a dataset sample into a Document."""
        self._tfds_client = TensorflowDatasets(
            dataset_name=self.dataset_name,
            split_name=self.split_name,
            load_max_docs=self.load_max_docs,
            sample_to_document_function=self.sample_to_document_function,
        )

    def lazy_load(self) -> Iterator[Document]:
        yield from self._tfds_client.lazy_load()

    def load(self) -> List[Document]:
        return list(self.lazy_load())

@ -29,6 +29,7 @@ from langchain.utilities.searx_search import SearxSearchWrapper
from langchain.utilities.serpapi import SerpAPIWrapper
from langchain.utilities.spark_sql import SparkSQL
from langchain.utilities.sql_database import SQLDatabase
from langchain.utilities.tensorflow_datasets import TensorflowDatasets
from langchain.utilities.twilio import TwilioAPIWrapper
from langchain.utilities.wikipedia import WikipediaAPIWrapper
from langchain.utilities.wolfram_alpha import WolframAlphaAPIWrapper
@ -62,6 +63,7 @@ __all__ = [
"SearxSearchWrapper", "SearxSearchWrapper",
"SerpAPIWrapper", "SerpAPIWrapper",
"SparkSQL", "SparkSQL",
"TensorflowDatasets",
"TextRequestsWrapper", "TextRequestsWrapper",
"TextRequestsWrapper", "TextRequestsWrapper",
"TwilioAPIWrapper", "TwilioAPIWrapper",

@ -21,7 +21,7 @@ class ArxivAPIWrapper(BaseModel):
    It limits the Document content by doc_content_chars_max.
    Set doc_content_chars_max=None if you don't want to limit the content size.

    Attributes:
        top_k_results: number of the top-scored document used for the arxiv tool
        ARXIV_MAX_QUERY_LENGTH: the cut limit on the query used for the arxiv tool.
        load_max_docs: a limit to the number of loaded documents

@ -0,0 +1,111 @@
import logging
from typing import Any, Callable, Dict, Iterator, List, Optional

from pydantic import BaseModel, root_validator

from langchain.schema import Document

logger = logging.getLogger(__name__)


class TensorflowDatasets(BaseModel):
    """Access to the TensorFlow Datasets.

    The current implementation can work only with datasets that fit in memory.

    `TensorFlow Datasets` is a collection of datasets ready to use, with TensorFlow
    or other Python ML frameworks, such as Jax. All datasets are exposed
    as `tf.data.Datasets`.
    To get started see the Guide: https://www.tensorflow.org/datasets/overview and
    the list of datasets: https://www.tensorflow.org/datasets/catalog/overview#all_datasets

    You have to provide the sample_to_document_function: a function that converts
    a sample from the dataset-specific format to a Document.

    Attributes:
        dataset_name: the name of the dataset to load
        split_name: the name of the split to load. Defaults to "train".
        load_max_docs: a limit to the number of loaded documents. Defaults to 100.
        sample_to_document_function: a function that converts a dataset sample
            to a Document

    Example:
        .. code-block:: python

            from langchain.utilities import TensorflowDatasets

            def mlqaen_example_to_document(example: dict) -> Document:
                return Document(
                    page_content=decode_to_str(example["context"]),
                    metadata={
                        "id": decode_to_str(example["id"]),
                        "title": decode_to_str(example["title"]),
                        "question": decode_to_str(example["question"]),
                        "answer": decode_to_str(example["answers"]["text"][0]),
                    },
                )

            tsds_client = TensorflowDatasets(
                dataset_name="mlqa/en",
                split_name="train",
                load_max_docs=MAX_DOCS,
                sample_to_document_function=mlqaen_example_to_document,
            )
    """

    dataset_name: str = ""
    split_name: str = "train"
    load_max_docs: int = 100
    sample_to_document_function: Optional[Callable[[Dict], Document]] = None
    dataset: Any  #: :meta private:

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the python package exists in environment."""
        try:
            import tensorflow  # noqa: F401
        except ImportError:
            raise ImportError(
                "Could not import tensorflow python package. "
                "Please install it with `pip install tensorflow`."
            )
        try:
            import tensorflow_datasets
        except ImportError:
            raise ImportError(
                "Could not import tensorflow_datasets python package. "
                "Please install it with `pip install tensorflow-datasets`."
            )
        if values["sample_to_document_function"] is None:
            raise ValueError(
                "sample_to_document_function is None. "
                "Please provide a function that converts a dataset sample to"
                " a Document."
            )
        values["dataset"] = tensorflow_datasets.load(
            values["dataset_name"], split=values["split_name"]
        )
        return values

    def lazy_load(self) -> Iterator[Document]:
        """Download a selected dataset lazily.

        Returns: an iterator of Documents.
        """
        return (
            self.sample_to_document_function(s)
            for s in self.dataset.take(self.load_max_docs)
            if self.sample_to_document_function is not None
        )

    def load(self) -> List[Document]:
        """Download a selected dataset.

        Returns: a list of Documents.
        """
        return list(self.lazy_load())
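
A minimal usage sketch for the wrapper above; `to_document` is a hypothetical converter modeled on the `mlqaen_example_to_document` example in the docstring, and construction triggers `validate_environment`, which imports the TensorFlow packages and prepares the requested split:

```python
from langchain.schema import Document
from langchain.utilities.tensorflow_datasets import TensorflowDatasets


def to_document(example: dict) -> Document:
    # Decode the tf.string fields of an `mlqa/en` sample into plain Python str.
    def decode(t) -> str:
        return t.numpy().decode("utf-8")

    return Document(
        page_content=decode(example["context"]),
        metadata={"question": decode(example["question"])},
    )


tfds_client = TensorflowDatasets(
    dataset_name="mlqa/en",
    split_name="test",
    load_max_docs=10,
    sample_to_document_function=to_document,
)

# `lazy_load()` yields Documents one at a time; `load()` returns them as a list.
for doc in tfds_client.lazy_load():
    print(doc.metadata["question"])
```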

@ -0,0 +1,105 @@
"""Integration tests for the TensorFlow Dataset Loader."""
import pytest
from pydantic.error_wrappers import ValidationError
from langchain.document_loaders.tensorflow_datasets import TensorflowDatasetLoader
from langchain.schema.document import Document
# adding tensorflow and tensorflow_datasets to pyproject.toml is not working
# these tests can be run in isolation only
tensorflow = pytest.importorskip("tensorflow")
tensorflow_datasets = pytest.importorskip("tensorflow_datasets")
# placed here after checking for tensorflow package installation
import tensorflow as tf # noqa: E402
def decode_to_str(item: tf.Tensor) -> str:
return item.numpy().decode("utf-8")
def mlqaen_example_to_document(example: dict) -> Document:
return Document(
page_content=decode_to_str(example["context"]),
metadata={
"id": decode_to_str(example["id"]),
"title": decode_to_str(example["title"]),
"question": decode_to_str(example["question"]),
"answer": decode_to_str(example["answers"]["text"][0]),
},
)
MAX_DOCS = 10
@pytest.fixture
def tfds_client() -> TensorflowDatasetLoader:
return TensorflowDatasetLoader(
dataset_name="mlqa/en",
split_name="test",
load_max_docs=MAX_DOCS,
sample_to_document_function=mlqaen_example_to_document,
)
def test_load_success(tfds_client: TensorflowDatasetLoader) -> None:
"""Test that returns the correct answer"""
output = tfds_client.load()
assert isinstance(output, list)
assert len(output) == MAX_DOCS
assert isinstance(output[0], Document)
assert len(output[0].page_content) > 0
assert isinstance(output[0].page_content, str)
assert isinstance(output[0].metadata, dict)
def test_lazy_load_success(tfds_client: TensorflowDatasetLoader) -> None:
"""Test that returns the correct answer"""
output = list(tfds_client.lazy_load())
assert isinstance(output, list)
assert len(output) == MAX_DOCS
assert isinstance(output[0], Document)
assert len(output[0].page_content) > 0
assert isinstance(output[0].page_content, str)
assert isinstance(output[0].metadata, dict)
def test_load_fail_wrong_dataset_name() -> None:
"""Test that fails to load"""
with pytest.raises(ValidationError) as exc_info:
TensorflowDatasetLoader(
dataset_name="wrong_dataset_name",
split_name="test",
load_max_docs=MAX_DOCS,
sample_to_document_function=mlqaen_example_to_document,
)
assert "the dataset name is spelled correctly" in str(exc_info.value)
def test_load_fail_wrong_split_name() -> None:
"""Test that fails to load"""
with pytest.raises(ValidationError) as exc_info:
TensorflowDatasetLoader(
dataset_name="mlqa/en",
split_name="wrong_split_name",
load_max_docs=MAX_DOCS,
sample_to_document_function=mlqaen_example_to_document,
)
assert "Unknown split" in str(exc_info.value)
def test_load_fail_no_func() -> None:
"""Test that fails to load"""
with pytest.raises(ValidationError) as exc_info:
TensorflowDatasetLoader(
dataset_name="mlqa/en",
split_name="test",
load_max_docs=MAX_DOCS,
)
assert "Please provide a function" in str(exc_info.value)

@ -0,0 +1,90 @@
"""Integration tests for the TensorFlow Dataset client."""
import pytest
import tensorflow as tf
from pydantic.error_wrappers import ValidationError
from langchain.schema.document import Document
from langchain.utilities.tensorflow_datasets import TensorflowDatasets
# adding tensorflow and tensorflow_datasets to pyproject.toml is not working
# these tests can be tested in isolation only
tensorflow = pytest.importorskip("tensorflow")
tensorflow_datasets = pytest.importorskip("tensorflow_datasets")
def decode_to_str(item: tf.Tensor) -> str:
return item.numpy().decode("utf-8")
def mlqaen_example_to_document(example: dict) -> Document:
return Document(
page_content=decode_to_str(example["context"]),
metadata={
"id": decode_to_str(example["id"]),
"title": decode_to_str(example["title"]),
"question": decode_to_str(example["question"]),
"answer": decode_to_str(example["answers"]["text"][0]),
},
)
MAX_DOCS = 10
@pytest.fixture
def tfds_client() -> TensorflowDatasets:
return TensorflowDatasets(
dataset_name="mlqa/en",
split_name="test",
load_max_docs=MAX_DOCS,
sample_to_document_function=mlqaen_example_to_document,
)
def test_load_success(tfds_client: TensorflowDatasets) -> None:
"""Test that returns the correct answer"""
output = tfds_client.load()
assert isinstance(output, list)
assert len(output) == MAX_DOCS
assert isinstance(output[0], Document)
assert len(output[0].page_content) > 0
assert isinstance(output[0].page_content, str)
assert isinstance(output[0].metadata, dict)
def test_load_fail_wrong_dataset_name() -> None:
"""Test that fails to load"""
with pytest.raises(ValidationError) as exc_info:
TensorflowDatasets(
dataset_name="wrong_dataset_name",
split_name="test",
load_max_docs=MAX_DOCS,
sample_to_document_function=mlqaen_example_to_document,
)
assert "the dataset name is spelled correctly" in str(exc_info.value)
def test_load_fail_wrong_split_name() -> None:
"""Test that fails to load"""
with pytest.raises(ValidationError) as exc_info:
TensorflowDatasets(
dataset_name="mlqa/en",
split_name="wrong_split_name",
load_max_docs=MAX_DOCS,
sample_to_document_function=mlqaen_example_to_document,
)
assert "Unknown split" in str(exc_info.value)
def test_load_fail_no_func() -> None:
"""Test that fails to load"""
with pytest.raises(ValidationError) as exc_info:
TensorflowDatasets(
dataset_name="mlqa/en",
split_name="test",
load_max_docs=MAX_DOCS,
)
assert "Please provide a function" in str(exc_info.value)