Create a mixed-data generator (images, CSV) in Keras

前端 未结 2 1120
盖世英雄少女心
盖世英雄少女心 2020-12-16 20:43

I am building a model with multiple inputs, as shown on PyImageSearch. However, I can't load all the images into RAM, so I am trying to create a generator that uses flow_fr…

相关标签:
2条回答
  • 2020-12-16 21:10

    I found a solution based on Luke's answer, using a custom generator:

    import os
    import random
    from glob import glob

    import numpy as np
    import pandas as pd
    from keras.preprocessing import image as krs_image
    from keras.preprocessing.image import ImageDataGenerator
    
    # Augmentation / preprocessing settings, expressed as ImageDataGenerator
    # constructor arguments.
    data_gen_args = dict(
        horizontal_flip=True,
        brightness_range=[0.5, 1.5],
        shear_range=10,
        channel_shift_range=50,
        rescale=1. / 255,
    )

    # Configure the generator with the settings above. Constructed with no
    # arguments (as before), it would ignore every entry of data_gen_args.
    datagen = ImageDataGenerator(**data_gen_args)

    # Collect every .JPG under images_dir/split (recursively) and load the CSV
    # for the same split, indexed so rows can be looked up by image name.
    # NOTE(review): images_dir, split, csv_dir and csv_data are not defined in
    # this snippet; they must be provided by the surrounding script.
    image_file_list = glob(f'{images_dir}/{split}/**/*.JPG', recursive=True)
    df = pd.read_csv(f'{csv_dir}/{split}.csv', index_col=csv_data[0])
    # Shuffle once up front so the first epoch is not in filesystem order.
    random.shuffle(image_file_list)
    
    def custom_generator(images_list, dataframe, batch_size):
        """Yield an endless stream of ``([images, csv_features], labels)`` batches.

        Each image path is matched to a row of ``dataframe`` through its file
        name without extension. ``dataframe`` is therefore expected to be
        indexed by image name and to contain a ``'class'`` column of integer
        class ids. All remaining columns of the row form the tabular input.

        Args:
            images_list: Paths of the images to serve. A private copy is
                reshuffled whenever the generator wraps around, so the caller's
                list is never reordered.
            dataframe: Table with one row per image and a ``'class'`` column.
            batch_size: Number of samples per yielded batch.

        Yields:
            ``([image_batch, csv_batch], one_hot_labels)``. ``image_batch``
            stacks the augmented and rescaled arrays from ``img_to_array``.
            ``one_hot_labels`` has ``num_classes`` columns.

        Raises:
            ValueError: If ``images_list`` is empty (on the first ``next()``).
        """
        if not images_list:
            raise ValueError('images_list must contain at least one image path')
        # Copy so that reshuffling on wrap-around does not mutate the caller's list.
        images_list = list(images_list)
        # Build the augmenter *from* data_gen_args. ImageDataGenerator.apply_transform
        # expects per-sample parameters (theta, shear, flip_horizontal, brightness...),
        # none of which appear in data_gen_args, so passing them there had no effect.
        augmenter = krs_image.ImageDataGenerator(**data_gen_args)

        i = 0
        while True:
            batch_images, batch_csv, batch_labels = [], [], []
            for _ in range(batch_size):
                if i == len(images_list):
                    # Completed a full pass: start over in a fresh random order.
                    i = 0
                    random.shuffle(images_list)
                image_path = images_list[i]
                i += 1

                # splitext strips only the real extension; str.replace would also
                # remove '.JPG' appearing elsewhere in the name.
                image_name = os.path.splitext(os.path.basename(image_path))[0]
                image = krs_image.load_img(image_path, target_size=(img_height, img_width))
                # Augment the numeric array first, then apply rescale via
                # standardize: transform-then-standardize, as Keras' iterators do.
                image = krs_image.img_to_array(image)
                image = augmenter.random_transform(image)
                image = augmenter.standardize(image)

                # Tabular features and label for this image, looked up by name.
                csv_row = dataframe.loc[image_name, :]
                batch_labels.append(csv_row['class'])
                batch_csv.append(csv_row.drop(labels='class'))
                batch_images.append(image)

            # Integer class ids select rows of the identity matrix -> one-hot.
            labels = np.eye(num_classes)[np.asarray(batch_labels, dtype=int)]
            yield [np.array(batch_images), np.array(batch_csv)], labels
    
    0 讨论(0)
  • 2020-12-16 21:27

    I would suggest creating a custom generator given this relatively specific case. Something like the following (modified from a similar answer here) should suffice:

    import os
    import random
    import pandas as pd
    
    def generator(image_dir, csv_dir, batch_size, n_classes=None):
        """Yield endless ``(inputs, one_hot_labels)`` batches for a two-input model.

        For every file in ``image_dir``, a CSV with the same base name is read
        from ``csv_dir`` (``foo.png`` -> ``foo.csv``). The image is passed
        through ``preprocess_image`` and the CSV through ``preprocess_feats``.
        Both helpers, and ``cv2``, are expected to be provided by the caller's
        module.

        Args:
            image_dir: Directory whose entries (as listed by ``os.listdir``) are
                all treated as images.
            csv_dir: Directory containing one CSV per image.
            batch_size: Number of samples per yielded batch.
            n_classes: Number of classes for one-hot encoding. When omitted,
                the module-level ``num_classes`` is used.

        Yields:
            ``(batch_x, batch_y)``. ``batch_x`` is a dict with keys
            ``'images'`` and ``'other_feats'``; these keys must match the
            model's input layer names. ``batch_y`` is a one-hot label array.

        Raises:
            ValueError: If ``image_dir`` is empty (on the first ``next()``).
            FileNotFoundError: If an image cannot be read by ``cv2.imread``.
        """
        import numpy as np  # the surrounding snippet never imports numpy

        image_file_list = os.listdir(image_dir)
        if not image_file_list:
            raise ValueError(f'no files found in {image_dir!r}')
        if n_classes is None:
            n_classes = num_classes  # fall back to the global the original relied on

        i = 0
        while True:
            batch_x = {'images': [], 'other_feats': []}  # one entry per model input
            labels = []
            for _ in range(batch_size):
                if i == len(image_file_list):
                    # Completed a full pass: start over in a fresh random order.
                    i = 0
                    random.shuffle(image_file_list)
                # listdir returns bare names, so join them with the directory.
                image_file_path = os.path.join(image_dir, image_file_list[i])
                i += 1
                image_name = os.path.splitext(os.path.basename(image_file_path))[0]
                csv_file_path = os.path.join(csv_dir, image_name + '.csv')

                raw_image = cv2.imread(image_file_path)
                if raw_image is None:  # cv2.imread signals failure by returning None
                    raise FileNotFoundError(f'could not read image {image_file_path!r}')
                csv_file = pd.read_csv(csv_file_path)

                batch_x['images'].append(preprocess_image(raw_image))
                batch_x['other_feats'].append(preprocess_feats(csv_file))
                # NOTE(review): assumes each per-image CSV has a 'class' column whose
                # first row holds the image's integer label; confirm against the data.
                labels.append(csv_file['class'].iloc[0])

            batch_x['images'] = np.array(batch_x['images'])  # convert each list to array
            batch_x['other_feats'] = np.array(batch_x['other_feats'])
            batch_y = np.eye(n_classes)[np.asarray(labels, dtype=int)]
            yield batch_x, batch_y
    

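    Then, you can use Keras's fit_generator() function to train your model. (In TensorFlow 2.1 and later, fit_generator() is deprecated and model.fit() accepts generators directly.)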

    Obviously, this assumes you have csv files with the same names as your image files, and that you have some custom preprocessing functions for images and csv files.

    0 讨论(0)
提交回复
热议问题