How can I explore and modify the created dataset from tf.keras.preprocessing.image_dataset_from_directory()?

穿精又带淫゛_ submitted on 2021-02-05 08:51:43

Question


Here's how I used the function:

dataset = tf.keras.preprocessing.image_dataset_from_directory(
    main_directory,
    labels='inferred',
    image_size=(299, 299),
    validation_split=0.1,
    subset='training',
    seed=123
)

I'd like to explore the created dataset much like in this example, particularly the part where it was converted to a pandas DataFrame. At a minimum, I want to check the labels and the number of files attached to each one, to confirm that the dataset was created as expected (each sub-directory name being the label of the images inside it).

To be clear, the main_directory is set up like this:

main_directory
- class_a
  - 000.jpg
  - ...
- class_b
  - 100.jpg
  - ...

And I'd like to see the dataset display its info with something like this:

label     number of images
class_a   100
class_b   100

Additionally, is it possible to remove labels and their corresponding images from a dataset? The idea is to drop a label if its number of images is below a certain threshold, or based on a different metric. It can of course be done outside this function by other means, but I'd like to know whether it is possible and, if so, how.

EDIT: For additional context, the end goal of all of this is to train a pre-trained model like this with local images divided into folders named after their classes. If there is a better way that includes not using that function and meets this end goal, it's welcome all the same. Thanks!


Answer 1:


I think it would be much easier to use glob2 to get all your filenames, process them as you want to, then make a simple loading function that will replace image_dataset_from_directory.

Get all your files:

import glob2, os
files = glob2.glob(os.path.join('class_*', '*.jpg'))  # os.path.join keeps the pattern portable across OSes

Then manipulate this list of filenames as desired.
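As a rough sketch of what that manipulation could look like for the two things asked in the question (counting images per label and dropping labels with too few images), something along these lines might work. The helper names `class_of`, `count_per_class` and `drop_small_classes` are made up for illustration, and `example_files` is a hypothetical stand-in for the list returned by glob2:

```python
import os
from collections import Counter

def class_of(file_path):
    """Return the name of the directory that directly contains the file."""
    return os.path.basename(os.path.dirname(file_path))

def count_per_class(file_paths):
    """Map each class (sub-directory name) to its number of files."""
    return Counter(class_of(p) for p in file_paths)

def drop_small_classes(file_paths, min_images):
    """Keep only files whose class has at least `min_images` files."""
    counts = count_per_class(file_paths)
    return [p for p in file_paths if counts[class_of(p)] >= min_images]

# Hypothetical file list standing in for the result of glob2.glob(...)
example_files = [
    os.path.join('class_a', '000.jpg'),
    os.path.join('class_a', '001.jpg'),
    os.path.join('class_b', '100.jpg'),
]

# Print a small "label / number of images" table
for label, n in sorted(count_per_class(example_files).items()):
    print(f'{label:<10}{n}')

# Drop every class that has fewer than 2 images
kept = drop_small_classes(example_files, min_images=2)
```

The filtered list can then be used in place of the original one when building the dataset.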

Then, make a function to load the images:

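import os
import tensorflow as tf

def load(file_path):
    img = tf.io.read_file(file_path)                     # raw bytes of the file
    img = tf.image.decode_jpeg(img, channels=3)          # uint8 RGB tensor
    img = tf.image.convert_image_dtype(img, tf.float32)  # scale to [0, 1]
    img = tf.image.resize(img, size=(299, 299))
    # The parent directory name is the class; [-2] also works for longer paths
    label = tf.strings.split(file_path, os.sep)[-2]
    # Binary label: 1 for class_a, 0 for anything else
    label = tf.cast(tf.equal(label, 'class_a'), tf.int32)
    return img, label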
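The load function above derives a binary label by comparing the directory name against 'class_a'. If there are more than two classes, one possible approach (a sketch, not part of the original answer) is to compute integer labels in plain Python before building the dataset. The helper name `build_label_index` and the example paths are hypothetical:

```python
import os

def build_label_index(file_paths):
    """Return (class_names, labels) where labels[i] is the integer
    index of the parent directory of file_paths[i] in class_names."""
    dir_names = [os.path.basename(os.path.dirname(p)) for p in file_paths]
    class_names = sorted(set(dir_names))  # alphabetical order, assumed convention
    class_to_index = {name: i for i, name in enumerate(class_names)}
    labels = [class_to_index[name] for name in dir_names]
    return class_names, labels

# Hypothetical file list standing in for the result of glob2.glob(...)
sample_files = [
    os.path.join('class_b', '100.jpg'),
    os.path.join('class_a', '000.jpg'),
    os.path.join('class_c', '200.jpg'),
]
class_names, labels = build_label_index(sample_files)
```

The file list and the labels could then be passed together, for example as `tf.data.Dataset.from_tensor_slices((files, labels))`, with a loading function that takes `(file_path, label)` and only decodes the image.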

Then create your dataset for training:

train_ds = tf.data.Dataset.from_tensor_slices(files).map(load).batch(4)

Then train:

model.fit(train_ds)
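If you would rather keep image_dataset_from_directory and only inspect what it produced, the returned dataset exposes a `class_names` attribute, and the labels can be counted by iterating over the batches. The sketch below assumes the default integer label mode; `count_labels` is a made-up helper, and `fake_batches` is a plain-Python stand-in so the example runs without TensorFlow:

```python
from collections import Counter

def count_labels(batches, class_names):
    """Count how often each integer label occurs in an iterable of
    (images, labels) batches and report the counts by class name."""
    counts = Counter()
    for _images, labels in batches:
        counts.update(int(label) for label in labels)
    return {name: counts.get(i, 0) for i, name in enumerate(class_names)}

# Stand-in for real batches; images are irrelevant for counting
fake_batches = [(None, [0, 0, 1]), (None, [1, 0])]
summary = count_labels(fake_batches, ['class_a', 'class_b'])

for name, n in summary.items():
    print(f'{name:<10}{n}')

# With a real dataset this would be called roughly as:
# count_labels(dataset.as_numpy_iterator(), dataset.class_names)
```

Note that with validation_split and subset='training', the counts cover only the training subset, not every file on disk.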


Source: https://stackoverflow.com/questions/64359945/how-can-i-explore-and-modify-the-created-dataset-from-tf-keras-preprocessing-ima
