langchain/libs/community/langchain_community/utilities/tensorflow_datasets.py

import logging
from typing import Any, Callable, Dict, Iterator, List, Optional

from langchain_core.documents import Document
from pydantic import BaseModel, model_validator

logger = logging.getLogger(__name__)


class TensorflowDatasets(BaseModel):
"""Access to the TensorFlow Datasets.
The Current implementation can work only with datasets that fit in a memory.
`TensorFlow Datasets` is a collection of datasets ready to use, with TensorFlow
or other Python ML frameworks, such as Jax. All datasets are exposed
as `tf.data.Datasets`.
To get started see the Guide: https://www.tensorflow.org/datasets/overview and
the list of datasets: https://www.tensorflow.org/datasets/catalog/
overview#all_datasets
You have to provide the sample_to_document_function: a function that
a sample from the dataset-specific format to the Document.
Attributes:
dataset_name: the name of the dataset to load
split_name: the name of the split to load. Defaults to "train".
load_max_docs: a limit to the number of loaded documents. Defaults to 100.
sample_to_document_function: a function that converts a dataset sample
to a Document
Example:
.. code-block:: python
from langchain_community.utilities import TensorflowDatasets
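
            # decode_to_str is not defined in this module. A minimal helper
            # (an assumption, mirroring the MLQA example in the docs) that
            # turns tf.Tensor byte strings into Python str:
            import tensorflow as tf

            def decode_to_str(item: tf.Tensor) -> str:
                return item.numpy().decode("utf-8")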

            def mlqaen_example_to_document(example: dict) -> Document:
                return Document(
                    page_content=decode_to_str(example["context"]),
                    metadata={
                        "id": decode_to_str(example["id"]),
                        "title": decode_to_str(example["title"]),
                        "question": decode_to_str(example["question"]),
                        "answer": decode_to_str(example["answers"]["text"][0]),
                    },
                )

            tsds_client = TensorflowDatasets(
                dataset_name="mlqa/en",
                split_name="train",
                load_max_docs=MAX_DOCS,
                sample_to_document_function=mlqaen_example_to_document,
            )
    """
dataset_name: str = ""
split_name: str = "train"
load_max_docs: int = 100
sample_to_document_function: Optional[Callable[[Dict], Document]] = None
dataset: Any #: :meta private:
@model_validator(mode="before")
@classmethod
def validate_environment(cls, values: Dict) -> Any:
"""Validate that the python package exists in environment."""
try:
import tensorflow # noqa: F401
except ImportError:
raise ImportError(
"Could not import tensorflow python package. "
"Please install it with `pip install tensorflow`."
)
try:
import tensorflow_datasets
except ImportError:
raise ImportError(
"Could not import tensorflow_datasets python package. "
"Please install it with `pip install tensorflow-datasets`."
)
if values["sample_to_document_function"] is None:
raise ValueError(
"sample_to_document_function is None. "
"Please provide a function that converts a dataset sample to"
" a Document."
)
values["dataset"] = tensorflow_datasets.load(
values["dataset_name"], split=values["split_name"]
)
return values

    def lazy_load(self) -> Iterator[Document]:
        """Download a selected dataset lazily.

        Returns: an iterator of Documents.
        """
        # The validator guarantees sample_to_document_function is set; the
        # inline check only narrows the Optional type for the generator.
        return (
            self.sample_to_document_function(s)
            for s in self.dataset.take(self.load_max_docs)
            if self.sample_to_document_function is not None
        )

    def load(self) -> List[Document]:
        """Download a selected dataset.

        Returns: a list of Documents.
        """
        return list(self.lazy_load())
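

# Usage sketch (illustrative; not part of the original module). It assumes the
# `mlqa/en` dataset and the `mlqaen_example_to_document` converter from the
# class docstring above:
#
#     tsds_client = TensorflowDatasets(
#         dataset_name="mlqa/en",
#         split_name="train",
#         load_max_docs=10,
#         sample_to_document_function=mlqaen_example_to_document,
#     )
#     for doc in tsds_client.lazy_load():  # streams Documents one at a time
#         print(doc.metadata["question"])
#     docs = tsds_client.load()  # or materialize up to load_max_docs at once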