I have a huge TFRecord file with more than 4M entries. It is a very unbalanced dataset: some labels account for most of the entries, while others are rare compared to the whole dataset. I want to take a limited number of entries for each label so that I end up with a balanced dataset. Below is my attempt, but it takes more than 24 hours to filter 1k entries from each of the 33 labels.
import tensorflow as tf

tf.compat.as_str(bytes_or_text='str', encoding='utf-8')

# Use a TPU if one is available, otherwise fall back to the default strategy
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver.connect()
    print("Device:", tpu.master())
    strategy = tf.distribute.TPUStrategy(tpu)
except Exception:
    strategy = tf.distribute.get_strategy()
print("Number of replicas:", strategy.num_replicas_in_sync)

# Allow non-deterministic ordering so reads are not serialized
ignore_order = tf.data.Options()
ignore_order.experimental_deterministic = False

dataset = tf.data.TFRecordDataset('/test.tfrecord')
dataset = dataset.with_options(ignore_order)

features, feature_lists = detect_schema(dataset)

# Decoding TFRecord serialized data
def decode_data(serialized):
    X, y = tf.io.parse_single_sequence_example(
        serialized,
        context_features=features,
        sequence_features=feature_lists)
    return X['title'], y['subject']

dataset = dataset.map(lambda x: tf.py_function(
    func=decode_data, inp=[x], Tout=(tf.string, tf.string)))

# Filtering and concatenating the samples
def balanced_dataset(dataset, labels_list, sample_size=1000):
    datasets_list = []
    for label in labels_list:
        # Keep only the elements whose label matches the current one
        label_dataset = dataset.filter(
            lambda x, y: tf.greater(
                tf.reduce_sum(
                    tf.cast(tf.equal(tf.constant(label, dtype=tf.int64), y),
                            tf.float32)),
                tf.constant(0.)))
        # Append a limited sample of that label
        datasets_list.append(label_dataset.take(sample_size))
    # Concatenate the per-label datasets into one
    concat_dataset = datasets_list[0]
    for dset in datasets_list[1:]:
        concat_dataset = concat_dataset.concatenate(dset)
    return concat_dataset

balanced_data = balanced_dataset(dataset,
                                 labels_list=list(decod_dic.values()),
                                 sample_size=1000)
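For reference, this is roughly how I inspect the result (a minimal sketch, assuming each element comes out of the pipeline above as a (title, subject) pair of string tensors):

# Sanity check: print a few decoded (title, subject) pairs from the balanced dataset.
# Note that even this small take() forces a scan over the filtered source data.
for title, subject in balanced_data.take(5).as_numpy_iterator():
    print(title, subject)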