Base class for defining a parallel dataset using Python code.
```python
tf.keras.utils.PyDataset(
    workers=1, use_multiprocessing=False, max_queue_size=10
)
```
Every `PyDataset` must implement the `__getitem__()` and the `__len__()` methods. If you want to modify your dataset between epochs, you may additionally implement `on_epoch_end()`. The `__getitem__()` method should return a complete batch (not a single sample), and the `__len__()` method should return the number of batches in the dataset (rather than the number of samples).
Notes:

- `PyDataset` is a safer way to do multiprocessing. This structure guarantees that the model will only train once on each sample per epoch, which is not the case with Python generators.
- The arguments `workers`, `use_multiprocessing`, and `max_queue_size` exist to configure how `fit()` uses parallelism to iterate over the dataset. They are not used by the `PyDataset` class directly. When you are manually iterating over a `PyDataset`, no parallelism is applied (see the usage sketch after the example below).
Example:
```python
import math

import keras
import numpy as np
from skimage.io import imread
from skimage.transform import resize


# Here, `x_set` is a list of paths to the images
# and `y_set` are the associated classes.

class CIFAR10PyDataset(keras.utils.PyDataset):

    def __init__(self, x_set, y_set, batch_size, **kwargs):
        super().__init__(**kwargs)
        self.x, self.y = x_set, y_set
        self.batch_size = batch_size

    def __len__(self):
        # Return number of batches.
        return math.ceil(len(self.x) / self.batch_size)

    def __getitem__(self, idx):
        # Return x, y for batch idx.
        low = idx * self.batch_size
        # Cap upper bound at array length; the last batch may be smaller
        # if the total number of items is not a multiple of batch size.
        high = min(low + self.batch_size, len(self.x))
        batch_x = self.x[low:high]
        batch_y = self.y[low:high]

        return np.array([
            resize(imread(file_name), (200, 200))
            for file_name in batch_x
        ]), np.array(batch_y)
```
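The sketch below illustrates how the parallelism arguments from the notes above are typically wired in; it is a hedged illustration, and `train_image_paths`, `train_labels`, and `model` are placeholder names rather than part of the example above.

```python
# Hypothetical usage sketch: the paths, labels, and model below are placeholders.
train_dataset = CIFAR10PyDataset(
    x_set=train_image_paths,   # list of image file paths (assumed to exist)
    y_set=train_labels,        # matching list of class labels
    batch_size=32,
    workers=4,                 # forwarded to PyDataset via **kwargs
    use_multiprocessing=True,
    max_queue_size=10,
)

# fit() applies the parallelism settings while iterating over the dataset;
# indexing the dataset yourself (e.g. train_dataset[0]) would not.
model.fit(train_dataset, epochs=10)
```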
Attributes | |
---|---|
`max_queue_size` | Maximum number of batches to keep in the queue when iterating over the dataset with parallelism. |
`num_batches` | Number of batches in the `PyDataset`. |
`use_multiprocessing` | Whether to use Python multiprocessing for parallelism. |
`workers` | Number of workers to use in multithreading or multiprocessing. |
Methods
on_epoch_end
`on_epoch_end()`
Method called at the end of every epoch.
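A minimal sketch of implementing this hook, assuming an in-memory dataset that keeps its samples and labels in Python lists; the `ShuffledPyDataset` class below is hypothetical and not part of the API:

```python
import math
import random

import keras
import numpy as np


class ShuffledPyDataset(keras.utils.PyDataset):
    """Hypothetical dataset that reshuffles its samples between epochs."""

    def __init__(self, x_set, y_set, batch_size, **kwargs):
        super().__init__(**kwargs)
        self.x, self.y = list(x_set), list(y_set)
        self.batch_size = batch_size

    def __len__(self):
        # Number of batches per epoch.
        return math.ceil(len(self.x) / self.batch_size)

    def __getitem__(self, idx):
        low = idx * self.batch_size
        high = min(low + self.batch_size, len(self.x))
        return np.array(self.x[low:high]), np.array(self.y[low:high])

    def on_epoch_end(self):
        # Reshuffle sample order so batches differ from one epoch to the next.
        paired = list(zip(self.x, self.y))
        random.shuffle(paired)
        self.x, self.y = (list(t) for t in zip(*paired))
```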
__getitem__
`__getitem__(index)`

Gets batch at position `index`.
Args | |
---|---|
`index` | Position of the batch in the `PyDataset`. |

Returns | |
---|---|
A batch. |
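As a usage sketch (the `CIFAR10PyDataset` instance, paths, and labels here are placeholders), indexing a `PyDataset` manually calls `__getitem__()` directly in the calling process, with no parallelism applied:

```python
# Hypothetical manual iteration; no workers or prefetch queue are involved here.
dataset = CIFAR10PyDataset(x_set=image_paths, y_set=labels, batch_size=32)

batch_x, batch_y = dataset[0]          # equivalent to dataset.__getitem__(0)
print(batch_x.shape, batch_y.shape)    # e.g. (32, 200, 200, 3) and (32,)

for idx in range(len(dataset)):        # len() uses __len__(), i.e. the batch count
    batch_x, batch_y = dataset[idx]
```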