问题
I am trying to build a distributed custom training loop in TensorFlow 2.0, but I can't figure out how to annotate the autograph tf.function signature in order to avoid retracing.
I have tried to use DatasetSpec and various combinations of TensorSpec tuples, but I get all sorts of errors.
My question
Is it possible to specify a tf.function input signature that accepts batched distributed datasets?
Minimal reproducing code
import tensorflow as tf
from tensorflow import keras
import numpy as np
class SimpleModel(keras.layers.Layer):
def __init__(self, name='simple_model', **kwargs):
super(SimpleModel, self).__init__(name=name, **kwargs)
self.w = self.add_weight(shape=(1, 1),
initializer=tf.constant_initializer(5.0),
trainable=True,
dtype=np.float32,
name='w')
def call(self, x):
return tf.matmul(x, self.w)
class Trainer:
def __init__(self):
self.mirrored_strategy = tf.distribute.MirroredStrategy()
with self.mirrored_strategy.scope():
self.simple_model = SimpleModel()
self.optimizer = tf.optimizers.Adam(learning_rate=0.01)
def train_batches(self, dataset):
dataset_dist = self.mirrored_strategy.experimental_distribute_dataset(dataset)
with self.mirrored_strategy.scope():
loss = self.train_batches_dist(dataset_dist)
return loss.numpy()
@tf.function(input_signature=(tf.data.DatasetSpec(element_spec=tf.TensorSpec(shape=(None, 1), dtype=tf.float32)),))
def train_batches_dist(self, dataset_dist):
total_loss = 0.0
for batch in dataset_dist:
losses = self.mirrored_strategy.experimental_run_v2(
Trainer.train_batch, args=(self, batch)
)
mean_loss = self.mirrored_strategy.reduce(tf.distribute.ReduceOp.MEAN, losses, axis=0)
total_loss += mean_loss
return total_loss
def train_batch(self, batch):
with tf.GradientTape() as tape:
losses = tf.square(2 * batch - self.simple_model(batch))
gradients = tape.gradient(losses, self.simple_model.trainable_weights)
self.optimizer.apply_gradients(zip(gradients, self.simple_model.trainable_weights))
return losses
def main():
values = np.random.sample((100, 1)).astype(np.float32)
dataset = tf.data.Dataset.from_tensor_slices(values)
dataset = dataset.batch(10)
trainer = Trainer()
for epoch in range(0, 100):
loss = trainer.train_batches(dataset)
print(loss / 10.0)
if __name__ == '__main__':
main()
Error message
TypeError: If shallow structure is a sequence, input must also be a sequence. Input has type: <class 'tensorflow.python.distribute.input_lib.DistributedDataset'>
来源:https://stackoverflow.com/questions/58484924/tf-function-input-signature-for-distributed-dataset-in-tensorflow-2-0