How to efficiently filter a specific number of entries per label and concatenate them into a single tf.data.Dataset

I have a huge TFRecord file with more than 4M entries. It is a very unbalanced dataset: some labels account for most of the entries while others appear only rarely compared to the whole dataset. I want to take a limited number of entries from each of these labels in order to build a balanced dataset. Below you can see my attempt, but it takes more than 24 hours to filter 1k entries from each label (33 different labels).

import tensorflow as tf

tf.compat.as_str(bytes_or_text='str', encoding='utf-8')

try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver.connect()
    print("Device:", tpu.master())
    strategy = tf.distribute.TPUStrategy(tpu)
except:
    strategy = tf.distribute.get_strategy()
print("Number of replicas:", strategy.num_replicas_in_sync)

ignore_order = tf.data.Options()
ignore_order.experimental_deterministic = False

dataset = tf.data.TFRecordDataset('/test.tfrecord')
dataset = dataset.with_options(ignore_order)
features, feature_lists = detect_schema(dataset)

# Decoding TFRecord serialized data
def decode_data(serialized):
    X, y = tf.io.parse_single_sequence_example(
        serialized,
        context_features=features,
        sequence_features=feature_lists)
    return X['title'], y['subject']

dataset = dataset.map(lambda x: tf.py_function(func=decode_data, inp=[x], Tout=(tf.string, tf.string)))

# Filtering and concatenating the samples
def balanced_dataset(dataset, labels_list, sample_size=1000):
    datasets_list = []
    for label in labels_list:
        # Filtering the chosen labels
        locals()[label] = dataset.filter(
            lambda x, y: tf.greater(
                tf.reduce_sum(tf.cast(tf.equal(tf.constant(label, dtype=tf.int64), y), tf.float32)),
                tf.constant(0.)))
        # Appending a limited sample
        datasets_list.append(locals()[label].take(sample_size))
    concat_dataset = datasets_list[0]
    # Concatenating the datasets
    for dset in datasets_list[1:]:
        concat_dataset = concat_dataset.concatenate(dset)
    return concat_dataset

balanced_data = balanced_dataset(tabledataset, labels_list=list(decod_dic.values()), sample_size=1000)

Hi @Marlon_Henrique_Teix,

Sorry for the delayed response.

I recommend passing num_parallel_calls=tf.data.AUTOTUNE to map, which lets tf.data pick the optimal number of parallel calls automatically and can make the decoding step noticeably faster. Alternatively, you could use the group_by_window method from tf.data to group the examples by label and take a fixed-size window per label, instead of filtering the whole dataset once per label. I’ve added a sample gist for this implementation with the group_by_window method; a fuller sketch follows the snippet below.

dataset = dataset.group_by_window(
    key_func=key_func,
    reduce_func=reduce_func,
    window_size=window_size)  # window_size = sample size of 1000
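
As a minimal, runnable sketch of what key_func, reduce_func, and window_size could look like when the labels are strings: the labels_list, sample_size, and small in-memory dataset below are placeholders standing in for your decoded (title, subject) pairs and the 33 labels, and label_table is a hypothetical helper used only because key_func must return an int64 key.

import tensorflow as tf

# Placeholders: stand-ins for the decoded (title, subject) dataset and the 33 labels.
labels_list = ["politics", "sports", "science"]
sample_size = 1000

titles = tf.constant(["t1", "t2", "t3", "t4", "t5", "t6"])
subjects = tf.constant(["politics", "sports", "science", "politics", "sports", "science"])
dataset = tf.data.Dataset.from_tensor_slices((titles, subjects))

# In the real pipeline the decode step would run here, e.g.
# dataset = dataset.map(decode_data, num_parallel_calls=tf.data.AUTOTUNE)

# key_func must return an int64 scalar, so map the string labels to integer ids.
label_table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(
        keys=tf.constant(labels_list),
        values=tf.range(len(labels_list), dtype=tf.int64)),
    default_value=-1)

def key_func(title, subject):
    # Group elements by their label id.
    return label_table.lookup(subject)

def reduce_func(key, window):
    # Keep at most sample_size elements from each label group.
    return window.take(sample_size)

balanced_data = dataset.group_by_window(
    key_func=key_func,
    reduce_func=reduce_func,
    window_size=sample_size)

for title, subject in balanced_data:
    print(title.numpy(), subject.numpy())

Keep in mind that group_by_window flushes a window each time window_size elements of the same key have accumulated, so labels with far more than sample_size entries will still produce multiple windows and may need an extra per-label cap; the benefit is that everything happens in a single pass over the data instead of 33 separate filter passes.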

In addition, the tf.data pipeline itself can usually be made more efficient with techniques such as prefetching, caching, interleaving, and shuffling; please refer to this documentation, and see the generic sketch below.
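
For illustration, a minimal sketch of such a pipeline, reusing the single '/test.tfrecord' file from the question; the parse_fn below is a placeholder for your own decode function, and the shuffle buffer and batch size are arbitrary.

import tensorflow as tf

AUTOTUNE = tf.data.AUTOTUNE

def parse_fn(serialized):
    # Placeholder decode; replace with your parse_single_sequence_example call.
    return serialized

filenames = ['/test.tfrecord']  # with several shards, interleave reads them concurrently

dataset = (
    tf.data.Dataset.from_tensor_slices(filenames)
    .interleave(tf.data.TFRecordDataset, num_parallel_calls=AUTOTUNE)
    .map(parse_fn, num_parallel_calls=AUTOTUNE)   # parse records in parallel
    .cache()                                      # cache after the expensive step if it fits
    .shuffle(buffer_size=10_000)
    .batch(32)
    .prefetch(AUTOTUNE)                           # overlap preprocessing with training
)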

Thank You.