I got the following model, an autoencoder with feedback to the input layers of both, encoder and decoder, which is very slow at the moment due to the for loop. But it seems TOO slow, even for that case. Is it possible to speed up the inference/training?
The model is:
class FRAE(tf.keras.Model): def __init__(self, latent_dim, shape, ht, n1, n2, n3, n4, bypass=False, trainable=True,**kwargs): super(FRAE, self).__init__(**kwargs) self.latent_dim = latent_dim self.shape = shape self.ht = ht self.buffer = tf.Variable(initial_value=tf.zeros(shape=(1,shape[0] * self.ht), dtype=tf.float32)) self.bypass = bypass self.quantizer = None self.trainable = trainable self.l1 = tf.keras.layers.Dense(n1, activation='swish', input_shape=shape) self.l2 = tf.keras.layers.Dense(n1, activation='swish') self.ls = tf.keras.layers.Dense(latent_dim, activation='swish') self.l3 = tf.keras.layers.Dense(n3, activation='swish') self.l4 = tf.keras.layers.Dense(n4, activation='swish') self.l5 = tf.keras.layers.Dense(shape[-1], activation='linear') def get_config(self): config = super(FRAE,self).get_config().copy() config.update({'latent_dim':self.latent_dim, 'bypass':self.bypass, 'quantizer':self.quantizer, "encoder":self.encoder, "buffer":self.buffer, 'decoder':self.decoder,"ht":self.ht, "shape":self.shape, "name":self.name}) return config def update_buffer(self, new_element): n = self.shape[0] new_element_expanded = tf.expand_dims(new_element, axis=0) self.buffer.assign(tf.keras.backend.concatenate([new_element_expanded, self.buffer[:, :-n]], axis=1)) def resetBuffer(self): self.buffer[:,:].assign(tf.zeros(shape=(1,self.shape[0] * self.ht), dtype=tf.float32)) @tf.function def call(self, x): if self.bypass is True: print("Bypassing FRAE", flush=True) return x else: x = tf.squeeze(x,axis=0) decoded = tf.TensorArray(tf.float32, size=tf.shape(x)[0]) for i in tf.range(tf.shape(x)[0]): xexpand = tf.expand_dims(x[i],axis=0) xin = tf.concat((xexpand, self.buffer), axis=1) encoded = self.ls(self.l2(self.l1(xin))) decin = tf.concat([encoded, self.buffer], axis=1) y = self.l5(self.l4(self.l3(decin))) decoded = decoded.write(i,y) i += 1 self.update_buffer(tf.squeeze(y)) tmp = tf.transpose(decoded.stack(),[1,0,2]) return tmp