tf.function input_signature for distributed dataset in tensorflow 2.0

∥☆過路亽.° 提交于 2020-07-22 09:32:12

问题


I am trying to build a distributed custom training loop in TensorFlow 2.0, but I can't figure out how to annotate the autograph tf.function signature in order to avoid retracing.

I have tried to use DatasetSpec and various combinations of TensorSpec tuples, but I get all sorts of errors.

My question

Is it possible to specify a tf.function input signature that accepts batched distributed datasets?

Minimal reproducing code

import tensorflow as tf
from tensorflow import keras
import numpy as np


class SimpleModel(keras.layers.Layer):
    def __init__(self, name='simple_model', **kwargs):
        super(SimpleModel, self).__init__(name=name, **kwargs)
        self.w = self.add_weight(shape=(1, 1),
                                 initializer=tf.constant_initializer(5.0),
                                 trainable=True,
                                 dtype=np.float32,
                                 name='w')

    def call(self, x):
        return tf.matmul(x, self.w)


class Trainer:
    def __init__(self):
        self.mirrored_strategy = tf.distribute.MirroredStrategy()

        with self.mirrored_strategy.scope():
            self.simple_model = SimpleModel()
            self.optimizer = tf.optimizers.Adam(learning_rate=0.01)

    def train_batches(self, dataset):
        dataset_dist = self.mirrored_strategy.experimental_distribute_dataset(dataset)

        with self.mirrored_strategy.scope():
            loss = self.train_batches_dist(dataset_dist)

        return loss.numpy()

    @tf.function(input_signature=(tf.data.DatasetSpec(element_spec=tf.TensorSpec(shape=(None, 1), dtype=tf.float32)),))
    def train_batches_dist(self, dataset_dist):
        total_loss = 0.0
        for batch in dataset_dist:
            losses = self.mirrored_strategy.experimental_run_v2(
                Trainer.train_batch, args=(self, batch)
            )
            mean_loss = self.mirrored_strategy.reduce(tf.distribute.ReduceOp.MEAN, losses, axis=0)

            total_loss += mean_loss
        return total_loss

    def train_batch(self, batch):
        with tf.GradientTape() as tape:
            losses = tf.square(2 * batch - self.simple_model(batch))

        gradients = tape.gradient(losses, self.simple_model.trainable_weights)
        self.optimizer.apply_gradients(zip(gradients, self.simple_model.trainable_weights))

        return losses


def main():
    values = np.random.sample((100, 1)).astype(np.float32)

    dataset = tf.data.Dataset.from_tensor_slices(values)
    dataset = dataset.batch(10)

    trainer = Trainer()
    for epoch in range(0, 100):
        loss = trainer.train_batches(dataset)
        print(loss / 10.0)


if __name__ == '__main__':
    main()

Error message

TypeError: If shallow structure is a sequence, input must also be a sequence. Input has type: <class 'tensorflow.python.distribute.input_lib.DistributedDataset'>

来源:https://stackoverflow.com/questions/58484924/tf-function-input-signature-for-distributed-dataset-in-tensorflow-2-0

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!