Load tensorflow images and create patches

Submitted by 我与影子孤独终老i on 2021-02-04 08:11:30

Question


I am using image_dataset_from_directory to load a very large RGB imagery dataset from disk into a Dataset. For example,

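import tensorflow as tf

# batch_size=32 and image_size=(256, 256) are the defaults, spelled out here.
dataset = tf.keras.preprocessing.image_dataset_from_directory(
    <directory>,
    label_mode=None,
    batch_size=32,
    image_size=(256, 256),
    seed=1,
    subset='training',
    validation_split=0.1)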

The Dataset has, say, 100000 images grouped into batches of size 32, yielding a tf.data.Dataset with spec (batch=32, height=256, width=256, channels=3).

I would like to extract patches from the images to create a new tf.data.Dataset with image spatial dimensions of, say, 64x64.

Therefore, I would like to create a new Dataset of 1,600,000 patches (16 non-overlapping 64x64 patches per 256x256 image), still in batches of 32, i.e. a tf.data.Dataset with spec (batch=32, height=64, width=64, channels=3).
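As a check on the arithmetic: with padding='VALID', tf.image.extract_patches fits floor((H - p) / s) + 1 patches of size p along an axis of length H with stride s, and the per-image count is the product over both axes. A small sketch, where the helper name patches_per_image is hypothetical and the stride defaults to the patch size (non-overlapping patches):

```python
def patches_per_image(height, width, patch, stride=None):
    """Patches per image for tf.image.extract_patches with padding='VALID'.

    stride defaults to patch, i.e. non-overlapping patches.
    """
    stride = stride or patch
    rows = (height - patch) // stride + 1
    cols = (width - patch) // stride + 1
    return rows * cols


num_images = 100_000  # dataset size from the question
for patch in (32, 64, 128):
    per_image = patches_per_image(256, 256, patch)
    total = num_images * per_image
    print(f"{patch}x{patch}: {per_image} per image, {total} patches, "
          f"{total // 32} full batches of 32")
```

For 64x64 patches this gives 16 per image and 1,600,000 in total; the same helper covers the several patch sizes mentioned in the update below.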

I've looked at the Dataset.window method and the tf.image.extract_patches function, but the documentation doesn't make clear how to use them to build the new Dataset I need for training on the patches. window seems to be geared toward 1D tensors, and extract_patches seems to work on arrays rather than on Datasets.
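tf.image.extract_patches does take a tensor, but each element of a batched Dataset already is a 4-D (batch, height, width, channels) tensor, so the function can be applied inside Dataset.map. A minimal sketch, assuming a zero-filled synthetic stand-in for the loaded images; the helper name extract is hypothetical:

```python
import tensorflow as tf

# Synthetic stand-in for the image_dataset_from_directory output (assumption):
# 40 float32 RGB images of 256x256, batched by 32.
images = tf.zeros([40, 256, 256, 3])
dataset = tf.data.Dataset.from_tensor_slices(images).batch(32)


def extract(batch, patch=64):
    # batch is a 4-D tensor (batch, height, width, channels).
    return tf.image.extract_patches(
        images=batch,
        sizes=[1, patch, patch, 1],
        strides=[1, patch, patch, 1],
        rates=[1, 1, 1, 1],
        padding='VALID')


patched = dataset.map(extract)
print(patched.element_spec)
```

Each batch comes out as (batch, 4, 4, 64 * 64 * 3): a 4x4 grid of patches per image, with every patch flattened into the last axis. Turning that into individual 64x64x3 patches still needs a reshape and a re-batch.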

Any suggestions on how to accomplish this?

UPDATE:

Just to clarify my needs: I am trying to avoid creating the patches manually on disk. First, that would be untenable in terms of disk space. Second, the patch size is not fixed; the experiments will be run over several patch sizes. So I do not want to perform the patch creation manually, either on disk or by loading the images into memory and patching them myself. I would prefer to have TensorFlow handle patch creation as part of the pipeline, to minimize disk and memory usage.


Answer 1:


What you're looking for is tf.image.extract_patches. Here's an example on MNIST that splits each 28x28 image into four 14x14 patches:

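import tensorflow as tf
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt

# MNIST test split as (image, label) pairs; each image is 28x28x1 uint8.
data = tfds.load('mnist', split='test', as_supervised=True)

def get_patches(x, y):
    # extract_patches expects a 4-D batch, so add a batch axis of size 1.
    # 14x14 windows with stride 14 give a 2x2 grid of non-overlapping
    # patches, returned with shape (1, 2, 2, 14 * 14 * 1).
    patches = tf.image.extract_patches(
        images=tf.expand_dims(x, 0),
        sizes=[1, 14, 14, 1],
        strides=[1, 14, 14, 1],
        rates=[1, 1, 1, 1],
        padding='VALID')
    # Unflatten into 4 separate 14x14x1 patches; the label stays per image.
    return tf.reshape(patches, (4, 14, 14, 1)), y

data = data.map(get_patches)

fig = plt.figure()
plt.subplots_adjust(wspace=.1, hspace=.2)
images, labels = next(iter(data))
for index, image in enumerate(images):
    ax = plt.subplot(2, 2, index + 1)
    ax.set_xticks([])
    ax.set_yticks([])
    # Drop the channel axis so imshow gets a 2-D grayscale array.
    ax.imshow(tf.squeeze(image, axis=-1).numpy(), cmap='gray')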
plt.show()
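Applied to the batched, label-free dataset from the question, the same call can produce a dataset of individual patches: map extract_patches over each batch, reshape the flattened patches back to (p, p, C), unbatch, and re-batch. A sketch under these assumptions: the helper name make_patch_dataset is hypothetical, the channel count is known statically (3 for RGB), border pixels that don't fill a whole patch are discarded (padding='VALID'), and random tensors stand in for the images on disk:

```python
import tensorflow as tf


def make_patch_dataset(dataset, patch_size, batch_size=32):
    """Re-batch a dataset of image batches (N, H, W, C) into batches of
    non-overlapping patch_size x patch_size patches (hypothetical helper)."""
    def to_patches(batch):
        channels = batch.shape[-1]  # assumed static, e.g. 3 for RGB
        patches = tf.image.extract_patches(
            images=batch,
            sizes=[1, patch_size, patch_size, 1],
            strides=[1, patch_size, patch_size, 1],
            rates=[1, 1, 1, 1],
            padding='VALID')
        # (N, rows, cols, p * p * C) -> (N * rows * cols, p, p, C)
        return tf.reshape(patches, [-1, patch_size, patch_size, channels])

    return (dataset
            .map(to_patches, num_parallel_calls=tf.data.AUTOTUNE)
            .unbatch()            # one element per patch
            .batch(batch_size)    # regroup patches into batches
            .prefetch(tf.data.AUTOTUNE))


# Synthetic stand-in for the image_dataset_from_directory output.
images = tf.random.uniform([8, 256, 256, 3])
source = tf.data.Dataset.from_tensor_slices(images).batch(4)

# The patch size is just an argument, so each experiment builds its own
# pipeline; nothing is written to disk.
for patch_size in (32, 64, 128):
    patches = make_patch_dataset(source, patch_size)
    print(patch_size, patches.element_spec.shape)
```

Patches are produced lazily as the dataset is iterated. Because the patches of one image stay adjacent after unbatch(), adding a .shuffle(buffer_size) before .batch() may be worth considering for training.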




Answer 2:


I believe you can use a Python generator class, which you can also pass to model.fit if you want. I once used this approach for label preprocessing.

The following generator loads a batch from your dataset, splits each image in it into tiles of shape tile_shape, and appends the tiles to a queue. Once the queue holds at least batch_size tiles, the next batch is returned.

For simplicity, the example builds the dataset with from_tensor_slices; you can, of course, replace it with yours.

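import tensorflow as tf

class TileDatasetGenerator:
    """Re-batches tiles cut from the image batches of another dataset."""

    def __init__(self, dataset, batch_size, tile_shape):
        self.dataset_iterator = iter(dataset)
        self.batch_size = batch_size
        self.tile_shape = tile_shape
        self.image_queue = None  # tiles waiting to be emitted

    def __iter__(self):
        return self

    def __next__(self):
        # Pull source batches until enough tiles are queued. When the source
        # is exhausted, next() raises StopIteration, which ends iteration
        # (leftover tiles that don't fill a whole batch are dropped).
        while not self._has_queued_enough_for_batch():
            batch = next(self.dataset_iterator)
            self._split_images(batch)
        return self._dequeue_batch()

    def _has_queued_enough_for_batch(self):
        return (self.image_queue is not None
                and self.image_queue.shape[0] >= self.batch_size)

    def _dequeue_batch(self):
        batch, remainder = tf.split(self.image_queue, [self.batch_size, -1], axis=0)
        self.image_queue = remainder
        return batch

    def _split_images(self, batch):
        # Cut each (H, W, C) image into non-overlapping tiles. A plain reshape
        # to [-1, tile_h, tile_w, C] would not give spatial tiles, so split H
        # and W into (grid, tile) axes and transpose first. Assumes H and W
        # are multiples of the tile size.
        tile_h, tile_w = self.tile_shape
        _, height, width, channels = batch.shape
        tiles = tf.reshape(batch, [-1, height // tile_h, tile_h, width // tile_w, tile_w, channels])
        tiles = tf.transpose(tiles, [0, 1, 3, 2, 4, 5])
        tiles = tf.reshape(tiles, [-1, tile_h, tile_w, channels])
        if self.image_queue is None:
            self.image_queue = tiles
        else:
            self.image_queue = tf.concat([self.image_queue, tiles], axis=0)


dataset = tf.data.Dataset.from_tensor_slices(tf.ones(shape=[128, 64, 64, 3]))
dataset = dataset.batch(32)  # batch() returns a new dataset; reassign it
generator = TileDatasetGenerator(dataset, batch_size=16, tile_shape=[32, 32])

for batch in generator: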
    tf.print(tf.shape(batch))
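Note that the tiles must be cut spatially. Reshaping an (N, 64, 64, 3) batch directly to [-1, 32, 32, 3] regroups values in row-major memory order, so each "tile" is really 16 consecutive image rows rewrapped into a 32x32 square. A sketch that splits height and width into grid and tile axes, transposes, and compares the result with tf.image.extract_patches (the helper name tile_images and the test tensor are hypothetical):

```python
import tensorflow as tf


def tile_images(batch, tile_h, tile_w):
    """Cut (N, H, W, C) images into non-overlapping (tile_h, tile_w) tiles,
    assuming H and W are multiples of the tile size."""
    _, h, w, c = batch.shape
    tiles = tf.reshape(batch, [-1, h // tile_h, tile_h, w // tile_w, tile_w, c])
    tiles = tf.transpose(tiles, [0, 1, 3, 2, 4, 5])  # (N, rows, cols, th, tw, C)
    return tf.reshape(tiles, [-1, tile_h, tile_w, c])


# Distinct values everywhere, so any misplaced pixel shows up.
batch = tf.reshape(tf.range(2 * 64 * 64 * 3, dtype=tf.float32), [2, 64, 64, 3])

tiles = tile_images(batch, 32, 32)
reference = tf.reshape(
    tf.image.extract_patches(images=batch, sizes=[1, 32, 32, 1],
                             strides=[1, 32, 32, 1], rates=[1, 1, 1, 1],
                             padding='VALID'),
    [-1, 32, 32, 3])
naive = tf.reshape(batch, [-1, 32, 32, 3])  # plain reshape

print(bool(tf.reduce_all(tiles == reference)))  # True
print(bool(tf.reduce_all(naive == reference)))  # False
```

Both approaches yield tiles in the same order (image, grid row, grid column), so either can serve as the splitting step in the generator.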

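Edit: It is also possible to convert the generator to a tf.data.Dataset. Dataset.from_generator expects a callable that returns an iterator, so either add a __call__ method that returns self, or pass a lambda that builds a fresh generator (which also lets the dataset be iterated more than once). The declared output type must match what the generator yields: float32 tensors here, not int64. output_signature is available in TF 2.4 and later.

new_dataset = tf.data.Dataset.from_generator(
    lambda: TileDatasetGenerator(dataset, batch_size=16, tile_shape=[32, 32]),
    output_signature=tf.TensorSpec(shape=(None, 32, 32, 3), dtype=tf.float32))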


Source: https://stackoverflow.com/questions/64326029/load-tensorflow-images-and-create-patches
