Dataset

To define a dataset, write a class that inherits from tflibs.datasets.BaseDataset.

For example:

>>> import os
>>> from tflibs.datasets import BaseDataset, ImageSpec, LabelSpec
>>>
>>> class CatDogDataset(BaseDataset):
...     def __init__(self, dataset_dir, image_size):
...         # Set before calling the base constructor, which may build the feature specs
...         self._image_size = image_size
...         BaseDataset.__init__(self, os.path.join(dataset_dir, 'cat_dog'))
...
...     @property
...     def tfrecord_filename(self):
...         return 'cat_dog.tfrecord'
...
...     def _init_feature_specs(self):
...         # Maps each feature name to the spec describing how it is stored
...         return {
...             'image': ImageSpec([self._image_size, self._image_size, 3]),
...             'label': LabelSpec(3, class_names=['Cat', 'Dog', 'Cookie'])
...         }
...
...     @classmethod
...     def add_arguments(cls, parser):
...         # Registers the dataset-specific command-line options
...         parser.add_argument('--image-size',
...                             type=int,
...                             default=128,
...                             help='Size of the output images.')

To write TFRecord files, create a dataset object and call write().

>>> dataset = CatDogDataset('/tmp/dataset', 64)
>>>
>>> images = ['/cat/abc.jpg', '/dog/def.jpg', '/cookie/ghi.jpg']
>>> labels = ['Cat', 'Dog', 'Cookie']
>>>
>>> def process_fn(item, feature_specs):
...     # Each item is one (image_path, label_str) pair from the collection
...     image_path, label_str = item
...     # The example ID is the file name without directory or extension
...     id_string = os.path.splitext(os.path.basename(image_path))[0]
...
...     def build_example(_id, image, label):
...         return {
...             '_id': _id.create_with_string(id_string),
...             'image': image.create_with_path(image_path),
...             'label': label.create_with_label(label_str),
...         }
...
...     return build_example(**feature_specs)
...
>>> # write() expects a list, so materialize the zip iterator
>>> dataset.write(list(zip(images, labels)), process_fn)
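Because process_fn is an ordinary function, its contract can be checked without TensorFlow. The minimal sketch below assumes hypothetical StubSpec objects in place of the real tflibs specs; they only record what they were asked to build, whereas the real specs create the actual feature values.

```python
import os


class StubSpec:
    """Hypothetical stand-in for a tflibs feature spec (not the real API)."""

    def create_with_string(self, value):
        return ('string', value)

    def create_with_path(self, path):
        return ('path', path)

    def create_with_label(self, label):
        return ('label', label)


def process_fn(item, feature_specs):
    # Each item is one (image_path, label_str) pair from the collection.
    image_path, label_str = item
    # The example ID is the file name without directory or extension.
    id_string = os.path.splitext(os.path.basename(image_path))[0]
    return {
        '_id': feature_specs['_id'].create_with_string(id_string),
        'image': feature_specs['image'].create_with_path(image_path),
        'label': feature_specs['label'].create_with_label(label_str),
    }


specs = {'_id': StubSpec(), 'image': StubSpec(), 'label': StubSpec()}
images = ['/cat/abc.jpg', '/dog/def.jpg', '/cookie/ghi.jpg']
labels = ['Cat', 'Dog', 'Cookie']
examples = [process_fn(item, specs) for item in zip(images, labels)]
print(examples[0]['_id'])  # ('string', 'abc')
```

Each returned dict has one entry per feature name, which is the shape of data write() consumes.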

To read the TFRecord files back, create a dataset object and call read().

>>> dataset = CatDogDataset('/tmp/dataset', 64)
>>>
>>> # Returns a `tf.data.Dataset` whose elements have this structure:
>>> # {
>>> #   '_id': {
>>> #       'dtype': tf.string,
>>> #       'shape': (),
>>> #   },
>>> #   'image': {
>>> #       'dtype': tf.uint8,
>>> #       'shape': [64, 64, 3],
>>> #   },
>>> #   'label': {
>>> #       'dtype': tf.int64,
>>> #       'shape': [3],
>>> #   }
>>> # }
>>> tfdataset = dataset.read()
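The read() output above gives label a shape of [3], matching the three class names passed to LabelSpec. One plausible reading is a one-hot encoding; the sketch below assumes that, and the helper to_one_hot is hypothetical rather than part of tflibs.

```python
def to_one_hot(label_str, class_names):
    """Hypothetical helper: assumes the label is one-hot over class_names."""
    # The position of the label in class_names becomes the "hot" index.
    index = class_names.index(label_str)
    return [1 if i == index else 0 for i in range(len(class_names))]


class_names = ['Cat', 'Dog', 'Cookie']
print(to_one_hot('Dog', class_names))  # [0, 1, 0]
```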

class tflibs.datasets.dataset.BaseDataset(dataset_dir)

A base class for defining a dataset.

Parameters: dataset_dir (str) – Directory where the TFRecord files are stored.

classmethod add_arguments(parser)

Adds the dataset-specific command-line arguments to the parser.

Called when tflibs.runner.DatasetInitializer creates a dataset object.

Parameters: parser (argparse.ArgumentParser) – Argument parser to which the arguments are added.
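How DatasetInitializer uses this hook is not shown here. As a minimal sketch of the likely flow, the dataset class registers its options on a shared parser and the parsed values are then used to construct the dataset. CatDogArgs and the --dataset-dir flag are hypothetical; only the --image-size argument comes from the example above.

```python
import argparse


class CatDogArgs:
    """Hypothetical class mirroring CatDogDataset.add_arguments."""

    @classmethod
    def add_arguments(cls, parser):
        parser.add_argument('--image-size',
                            type=int,
                            default=128,
                            help='Size of the output images.')


# Hypothetical initializer flow: a shared parser, then the dataset's own options.
parser = argparse.ArgumentParser()
parser.add_argument('--dataset-dir', default='/tmp/dataset')  # hypothetical flag
CatDogArgs.add_arguments(parser)

args = parser.parse_args(['--image-size', '64'])
print(args.dataset_dir, args.image_size)  # /tmp/dataset 64
```

argparse converts --image-size to the attribute image_size, so the parsed values could be passed on as CatDogDataset(args.dataset_dir, args.image_size).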

feature_specs

read(split=None)

Reads the TFRecord files and returns them as a tf.data.Dataset.

Parameters: split – Split name ('train' or 'test')
Returns: A tf.data.Dataset whose elements are dicts mapping feature names to tensors

tfrecord_filename

Returns the name of the TFRecord file.

Subclasses must implement this property when defining a dataset.

Returns: TFRecord filename
Return type: str
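The documentation only states that subclasses should implement tfrecord_filename. One way to enforce such a contract is an abstract property; the sketch below uses abc for that purpose. SketchBaseDataset, CatDog, and Broken are hypothetical, and this is not necessarily how BaseDataset itself is written.

```python
import abc


class SketchBaseDataset(abc.ABC):
    """Hypothetical stand-in for BaseDataset, showing only the subclass contract."""

    def __init__(self, dataset_dir):
        self.dataset_dir = dataset_dir

    @property
    @abc.abstractmethod
    def tfrecord_filename(self):
        """Subclasses return the TFRecord file name, e.g. 'cat_dog.tfrecord'."""


class CatDog(SketchBaseDataset):
    @property
    def tfrecord_filename(self):
        return 'cat_dog.tfrecord'


class Broken(SketchBaseDataset):
    pass  # forgets to implement tfrecord_filename


print(CatDog('/tmp/dataset/cat_dog').tfrecord_filename)  # cat_dog.tfrecord

# With abc, a subclass missing the property cannot be instantiated at all.
try:
    Broken('/tmp/dataset')
    broken_rejected = False
except TypeError as err:
    broken_rejected = True
    print('TypeError:', err)
```

The design choice here is to fail at construction time rather than on first access of the missing property.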

write(collection, process_fn, num_parallel_calls=16, test_size=None)

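Writes examples to TFRecord files.

Parameters:
  • collection (list) – Items to convert; each item is passed to process_fn.
  • process_fn (function) – Called as process_fn(item, feature_specs); returns a dict mapping each feature name to the value built from its spec.
  • num_parallel_calls (int) – Number of process_fn calls to run in parallel.
  • test_size (int) – Number of examples to write to the test split.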
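The internals of write() are not documented here. As a rough mental model only, the sketch below assumes that process_fn is applied to every item with up to num_parallel_calls workers and that test_size examples are split off for the test split. write_sketch is hypothetical and returns lists instead of writing files. Which examples land in the test split (here, the last ones) is also an assumption.

```python
from concurrent.futures import ThreadPoolExecutor


def write_sketch(collection, process_fn, feature_specs,
                 num_parallel_calls=16, test_size=None):
    """Hypothetical outline of write(): returns (train, test) example lists."""
    with ThreadPoolExecutor(max_workers=num_parallel_calls) as pool:
        # map() keeps the input order even though calls run concurrently.
        examples = list(pool.map(lambda item: process_fn(item, feature_specs),
                                 collection))
    if not test_size:
        return examples, []
    # Assumption: the last test_size examples form the test split.
    return examples[:-test_size], examples[-test_size:]


def double(item, feature_specs):
    # Toy process_fn: builds a one-feature example from an integer item.
    return {'value': item * 2}


train, test = write_sketch(list(range(10)), double, feature_specs={},
                           num_parallel_calls=4, test_size=3)
print(len(train), len(test))  # 7 3
```

Threads suit this sketch because process_fn in the real example mostly reads files; a process pool would be the alternative for CPU-heavy preprocessing.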