Add a resizing layer to a keras sequential model

后端 未结 5 885
失恋的感觉
失恋的感觉 2020-12-17 18:02

How can I add a resizing layer to

model = Sequential()

using

model.add(...)

To resize an image from sha

5条回答
  •  借酒劲吻你
    2020-12-17 18:52

    A modification of @KeithWM 's answer, adding output_scale, e.g. output_scale=2 means the output is 2 times the input shape :)

    class ResizeImages(Layer):
        """Resize Images to a specified size
        https://stackoverflow.com/questions/41903928/add-a-resizing-layer-to-a-keras-sequential-model
    
        # Arguments
            output_dim: Size of output layer width and height
            output_scale: scale compared with input
            data_format: A string,
                one of `channels_last` (default) or `channels_first`.
                The ordering of the dimensions in the inputs.
                `channels_last` corresponds to inputs with shape
                `(batch, height, width, channels)` while `channels_first`
                corresponds to inputs with shape
                `(batch, channels, height, width)`.
                It defaults to the `image_data_format` value found in your
                Keras config file at `~/.keras/keras.json`.
                If you never set it, then it will be "channels_last".
    
        # Input shape
            - If `data_format='channels_last'`:
                4D tensor with shape:
                `(batch_size, rows, cols, channels)`
            - If `data_format='channels_first'`:
                4D tensor with shape:
                `(batch_size, channels, rows, cols)`
    
        # Output shape
            - If `data_format='channels_last'`:
                4D tensor with shape:
                `(batch_size, pooled_rows, pooled_cols, channels)`
            - If `data_format='channels_first'`:
                4D tensor with shape:
                `(batch_size, channels, pooled_rows, pooled_cols)`
        """
    
        def __init__(self, output_dim=(1, 1), output_scale=None, data_format=None, **kwargs):
            super(ResizeImages, self).__init__(**kwargs)
            data_format = normalize_data_format(data_format)  # does not have
            self.naive_output_dim = conv_utils.normalize_tuple(output_dim,
                                                               2, 'output_dim')
            self.naive_output_scale = output_scale
            self.data_format = normalize_data_format(data_format)
            self.input_spec = InputSpec(ndim=4)
    
        def build(self, input_shape):
            self.input_spec = [InputSpec(shape=input_shape)]
            if self.naive_output_scale is not None:
                if self.data_format == 'channels_first':
                    self.output_dim = (self.naive_output_scale * input_shape[2],
                                       self.naive_output_scale * input_shape[3])
                elif self.data_format == 'channels_last':
                    self.output_dim = (self.naive_output_scale * input_shape[1],
                                       self.naive_output_scale * input_shape[2])
            else:
                self.output_dim = self.naive_output_dim
    
        def compute_output_shape(self, input_shape):
            if self.data_format == 'channels_first':
                return (input_shape[0], input_shape[1], self.output_dim[0], self.output_dim[1])
            elif self.data_format == 'channels_last':
                return (input_shape[0], self.output_dim[0], self.output_dim[1], input_shape[3])
    
        def _resize_fun(self, inputs, data_format):
            try:
                assert keras.backend.backend() == 'tensorflow'
                assert self.data_format == 'channels_last'
            except AssertionError:
                print("Only tensorflow backend is supported for the resize layer and accordingly 'channels_last' ordering")
            output = tf.image.resize_images(inputs, self.output_dim)
            return output
    
        def call(self, inputs):
            output = self._resize_fun(inputs=inputs, data_format=self.data_format)
            return output
    
        def get_config(self):
            config = {'output_dim': self.output_dim,
                      'padding': self.padding,
                      'data_format': self.data_format}
            base_config = super(ResizeImages, self).get_config()
            return dict(list(base_config.items()) + list(config.items()))
    

提交回复
热议问题