Multi-class classification using InceptionV3, VGG16 with 101 classes: very low accuracy

独自空忆成欢 submitted on 2020-06-17 09:59:06

Question


I am trying to build a food classification model with 101 classes. The dataset has 1000 images for each class. The accuracy of the model I trained is below 6%. I have tried NASNet and VGG16 with ImageNet weights, but the accuracy did not increase. I have tried the Adam optimizer with and without amsgrad. I have also tried changing the learning rate to both 0.01 and 0.0001, but the accuracy still remains in the single digits. Please suggest methods to increase the accuracy to at least 60 percent. Due to hardware restrictions (MacBook Air 2017), I cannot train a very deep model.

Dataset: https://www.kaggle.com/kmader/food41

import tensorflow as tf
from tensorflow.keras.applications.inception_v3 import InceptionV3


train_data_dir=".../food_data/images"

data=tf.keras.preprocessing.image.ImageDataGenerator(
    featurewise_center=False,
    samplewise_center=False,
    featurewise_std_normalization=False,
    samplewise_std_normalization=False,
    zca_whitening=False,
    zca_epsilon=1e-06,
    rotation_range=45,
    width_shift_range=0.2,
    height_shift_range=0.2,
    brightness_range=None,
    shear_range=0.2,
    zoom_range=0.2,
    channel_shift_range=0.0,
    fill_mode="nearest",
    cval=0.0,
    horizontal_flip=True,
    vertical_flip=True,
    rescale=1./255,
)
datagen=data.flow_from_directory(
        train_data_dir,
        target_size=(360, 360),
        batch_size=10,
        class_mode='categorical')


base_model = InceptionV3(weights='imagenet',input_shape=(360,360,3), include_top=False)

for layer in base_model.layers:
    layer.trainable = False


x = base_model.output
x = tf.keras.layers.GlobalAveragePooling2D()(x)
x = tf.keras.layers.Dropout(0.3)(x)
x = tf.keras.layers.Dense(1024, activation='relu')(x)
x = tf.keras.layers.BatchNormalization()(x)
x = tf.keras.layers.Dense(512, activation='relu')(x)
x = tf.keras.layers.Dense(256, activation='relu')(x)
predictions = tf.keras.layers.Dense(101, activation='softmax')(x)
model = tf.keras.models.Model(inputs=base_model.input, outputs=predictions)

adam=tf.keras.optimizers.Adam(
    learning_rate=0.001,
    beta_1=0.9,
    beta_2=0.999,
    epsilon=1e-07,
    amsgrad=False,
    name="Adam",
)

model.compile(optimizer='rmsprop', loss='categorical_crossentropy',metrics=['accuracy'])
model.fit_generator(datagen,steps_per_epoch=100,epochs=50)

model.save('trained_food_new.h5')

Answer 1:


There are a few things that may improve the classification accuracy:

  1. Use EfficientNet with noisy_student weights. It has fewer parameters to train, and its scalable architecture tends to give better accuracy.

  2. You can use test-time augmentation. In your test data generator, apply a simple horizontal flip, a vertical flip (if the flipped data still looks realistic) and affine transformations. This generates multiple views of each image, and averaging the model's predictions over those views helps it settle on the most probable class.

  3. Check out the imgaug library (embossing, sharpening, noise addition, etc.). In addition, the random_eraser, cutout and mixup strategies have proved to be useful.

  4. Try label smoothing. It can also help your classifier to give more probability to the correct class.

  5. Try learning rate warmup. Something like this:

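# Warmup -> sustain -> exponential-decay schedule.
# Illustrative values; they must satisfy LR_START <= LR_MAX and LR_MIN <= LR_MAX.
LR_START = 0.00001  # learning rate at epoch 0
LR_MAX = 0.00005    # peak learning rate reached after warmup
LR_MIN = 0.00001    # floor that the decay approaches
LR_RAMPUP_EPOCHS = 4
LR_SUSTAIN_EPOCHS = 6
LR_EXP_DECAY = .8


def lrfn(epoch):
    if epoch < LR_RAMPUP_EPOCHS:
        # Linear ramp from LR_START up to LR_MAX.
        lr = (LR_MAX - LR_START) / LR_RAMPUP_EPOCHS * epoch + LR_START
    elif epoch < LR_RAMPUP_EPOCHS + LR_SUSTAIN_EPOCHS:
        # Hold the peak rate.
        lr = LR_MAX
    else:
        # Exponential decay from LR_MAX towards LR_MIN.
        lr = (LR_MAX - LR_MIN) * LR_EXP_DECAY**(epoch - LR_RAMPUP_EPOCHS - LR_SUSTAIN_EPOCHS) + LR_MIN
    return lr
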
  6. You can also extract features and apply an ensemble classifier to them (XGBoost, AdaBoost, BaggingClassifier), or use triplet loss.
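
To make the label smoothing suggestion concrete, here is a minimal NumPy sketch. The helper names `smooth_labels` and `cross_entropy` and the value `epsilon=0.1` are assumptions chosen for illustration, not part of the answer. The idea is only that each one-hot target is blended with a uniform distribution over the 101 classes.

```python
import numpy as np


def smooth_labels(one_hot, epsilon=0.1):
    """Blend one-hot targets with a uniform distribution over the classes."""
    num_classes = one_hot.shape[-1]
    return one_hot * (1.0 - epsilon) + epsilon / num_classes


def cross_entropy(targets, probs):
    """Mean categorical cross-entropy between target and predicted distributions."""
    probs = np.clip(probs, 1e-7, 1.0)  # avoid log(0)
    return float(-np.mean(np.sum(targets * np.log(probs), axis=-1)))


# Toy example: 2 samples, 101 classes (as in the question).
labels = np.eye(101)[[3, 42]]
smoothed = smooth_labels(labels, epsilon=0.1)

# A uniform prediction gives a finite loss against the smoothed targets.
uniform_probs = np.full((2, 101), 1.0 / 101)
loss = cross_entropy(smoothed, uniform_probs)
```

In Keras the same effect is usually obtained by passing `label_smoothing` to `tf.keras.losses.CategoricalCrossentropy` rather than by modifying the labels by hand.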

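The test-time augmentation suggestion can be sketched with NumPy alone. `predict_with_tta` and `dummy_predict` are hypothetical names. `dummy_predict` merely stands in for something like `model.predict` so that the example runs without a trained network. Only a horizontal flip is shown here; other views could be appended to the list in the same way.

```python
import numpy as np


def predict_with_tta(predict_fn, images):
    """Average class probabilities over the original and horizontally flipped batch.

    predict_fn: callable mapping an (N, H, W, C) array to (N, num_classes) probabilities.
    """
    views = [images, images[:, :, ::-1, :]]  # second view is flipped along the width axis
    probs = np.stack([predict_fn(view) for view in views], axis=0)
    return probs.mean(axis=0)


def dummy_predict(batch):
    """Hypothetical 2-class 'model': compares the brightness of the left and right halves."""
    half = batch.shape[2] // 2
    left = batch[:, :, :half, :].mean(axis=(1, 2, 3))
    right = batch[:, :, half:, :].mean(axis=(1, 2, 3))
    logits = np.stack([left, right], axis=1)
    exp = np.exp(logits - logits.max(axis=1, keepdims=True))
    return exp / exp.sum(axis=1, keepdims=True)


rng = np.random.default_rng(0)
batch = rng.random((4, 8, 8, 3))
avg_probs = predict_with_tta(dummy_predict, batch)
```

Because the stand-in model depends only on left versus right brightness, the flip swaps its two outputs and the average comes out as an even split. With a real classifier the averaged probabilities would simply be less sensitive to image orientation.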

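As a rough illustration of the mixup strategy mentioned in the list, the sketch below blends each image and its one-hot label with those of a randomly chosen partner from the same batch. The function name `mixup_batch`, the `alpha=0.2` setting and the toy array shapes are assumptions for illustration only.

```python
import numpy as np


def mixup_batch(images, labels, alpha=0.2, rng=None):
    """Return convex combinations of a batch with a shuffled copy of itself."""
    if rng is None:
        rng = np.random.default_rng()
    lam = rng.beta(alpha, alpha)         # mixing coefficient in [0, 1]
    perm = rng.permutation(len(images))  # partner index for each sample
    mixed_images = lam * images + (1.0 - lam) * images[perm]
    mixed_labels = lam * labels + (1.0 - lam) * labels[perm]
    return mixed_images, mixed_labels


# Toy batch: 10 small random "images" with one-hot labels over 101 classes.
rng = np.random.default_rng(0)
images = rng.random((10, 32, 32, 3))
labels = np.eye(101)[rng.integers(0, 101, size=10)]
mixed_x, mixed_y = mixup_batch(images, labels, alpha=0.2, rng=rng)
```

The mixed labels are soft targets, so they only make sense with a loss that accepts probability vectors, such as categorical cross-entropy.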
Source: https://stackoverflow.com/questions/62307806/multi-class-classification-using-inceptionv3-vgg16-with-101-classes-very-low-acc
