How to do point-wise categorical crossentropy loss in Keras?

Posted by 风流意气都作罢 on 2019-12-30 18:27:04

Question


I have a network that produces a 4D output tensor where the value at each position in spatial dimensions (~pixel) is to be interpreted as the class probabilities for that position. In other words, the output is (num_batches, height, width, num_classes). I have labels of the same size where the real class is coded as one-hot. I would like to calculate the categorical-crossentropy loss using this.

Problem #1: The K.softmax function expects a 2D tensor of shape (num_batches, num_classes).

Problem #2: I'm not sure how the losses from each position should be combined. Is it correct to reshape the tensor to (num_batches * height * width, num_classes) and then call K.categorical_crossentropy on that? Or should I instead call K.categorical_crossentropy on a (num_batches, num_classes) slice height*width times, once per position, and average the results?


Answer 1:


Found this issue to confirm my intuition.

In short: the softmax will take 2D or 3D inputs. If they are 3D, Keras will assume a shape like (samples, timedimension, numclasses) and apply the softmax on the last axis. For some reason, it doesn't do that for 4D tensors.

Solution: reshape your output to a sequence of pixels

reshaped_output = Reshape((height*width, num_classes))(output_tensor)

Then apply your softmax

new_output = Activation('softmax')(reshaped_output) 

And then either you reshape your target tensors to 2D, i.e. (height*width, num_classes) per sample, or you reshape that last layer's output back into (height, width, num_classes).
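The reshape round trip described above can be checked outside Keras. The following is a minimal NumPy sketch, not the Keras implementation: the helper `softmax_last_axis` and the tensor sizes are assumptions made up for illustration. It suggests that reshaping to a sequence of pixels, applying softmax, and reshaping back gives the same result as a softmax taken directly over the last axis.

```python
import numpy as np

def softmax_last_axis(x):
    # Hypothetical helper: subtract the per-position max for numerical
    # stability, then normalize over the last (class) axis.
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

# Illustrative sizes only; they are not taken from the original post.
batch, height, width, num_classes = 2, 3, 4, 5
rng = np.random.default_rng(0)
logits = rng.normal(size=(batch, height, width, num_classes))

# Reshape to a sequence of pixels, apply softmax, then restore the image layout.
as_sequence = logits.reshape(batch, height * width, num_classes)
probs = softmax_last_axis(as_sequence).reshape(batch, height, width, num_classes)

# Each pixel's class probabilities should sum to 1, and the round trip should
# match a softmax applied directly over the last axis of the 4D tensor.
sums_to_one = np.allclose(probs.sum(axis=-1), 1.0)
matches_direct = np.allclose(probs, softmax_last_axis(logits))
print(sums_to_one, matches_direct)
```

Because the reshape only regroups positions and never mixes values across the class axis, the per-pixel normalization is unaffected by it.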

Otherwise, something I would try if I weren't on my phone right now is TimeDistributed(Activation('softmax')). But I have no idea if that would work... I will try later.

I hope this helps :-)




Answer 2:


Just flatten the output to a 2D tensor of size (num_batches, height * width * num_classes). You can do this with the Flatten layer. Ensure that your y is flattened the same way (normally calling y = y.reshape((num_batches, height * width * num_classes)) is enough).

For your second question, using categorical crossentropy over all width*height predictions is essentially the same as averaging the categorical crossentropy of each of the width*height predictions (by the definition of categorical crossentropy).
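The two options from the question can be compared numerically. This is a rough NumPy sketch under stated assumptions: the function `categorical_crossentropy` below is a hand-written stand-in (not the Keras backend function), the sizes and random data are invented, and the flattened layout is the (num_batches * height * width, num_classes) one from the question.

```python
import numpy as np

def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    # Hand-written stand-in: per-position loss, summed over the class axis.
    return -np.sum(y_true * np.log(np.clip(y_pred, eps, 1.0)), axis=-1)

# Illustrative sizes and data only.
batch, height, width, num_classes = 2, 3, 4, 5
rng = np.random.default_rng(0)

# Random per-pixel probabilities, normalized over the class axis.
raw = rng.random(size=(batch, height, width, num_classes))
y_pred = raw / raw.sum(axis=-1, keepdims=True)

# Random integer labels turned into one-hot targets of the same 4D shape.
labels = rng.integers(0, num_classes, size=(batch, height, width))
y_true = np.eye(num_classes)[labels]

# Option 1: flatten all pixels into rows, one loss value per row, then average.
loss_flat = categorical_crossentropy(
    y_true.reshape(-1, num_classes), y_pred.reshape(-1, num_classes)
).mean()

# Option 2: per-pixel loss on the 4D tensors, then average over all positions.
loss_per_pixel = categorical_crossentropy(y_true, y_pred)  # (batch, height, width)
loss_avg = loss_per_pixel.mean()

print(np.isclose(loss_flat, loss_avg))
```

Both options reduce the same set of per-pixel losses, so under these assumptions the averages agree.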




Answer 3:


You could also skip the reshaping entirely and define both the softmax and the loss yourself. Here is a softmax that is applied over the last input dimension (as in the TF backend):

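from keras import backend as K

def image_softmax(x):
    # Softmax over the last (class) axis; works for tensors of any rank.
    label_dim = -1
    # Subtract the max along the class axis for numerical stability.
    d = K.exp(x - K.max(x, axis=label_dim, keepdims=True))
    return d / K.sum(d, axis=label_dim, keepdims=True)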

and here is the loss (there is no need to reshape anything):

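__EPS = 1e-5
def image_categorical_crossentropy(y_true, y_pred):
    # Clip predictions away from 0 and 1 to avoid log(0).
    y_pred = K.clip(y_pred, __EPS, 1 - __EPS)
    # Mean over every entry of the tensor; note the (1 - y_true) term.
    return -K.mean(y_true * K.log(y_pred) + (1 - y_true) * K.log(1 - y_pred))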

No further reshaping is needed.
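One detail worth checking: the loss above keeps a (1 - y_true) * log(1 - y_pred) term and averages over every class entry, which is the element-wise form used by binary crossentropy. A strictly categorical form would only sum y_true * log(y_pred) over the class axis. The NumPy sketch below compares the two on made-up numbers; the function names `elementwise_form` and `categorical_form` are hypothetical and exist only for this comparison.

```python
import numpy as np

EPS = 1e-5

def elementwise_form(y_true, y_pred):
    # NumPy port of the loss above: averages over every class entry and
    # keeps the (1 - y_true) term, as binary crossentropy does.
    y_pred = np.clip(y_pred, EPS, 1 - EPS)
    return -np.mean(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))

def categorical_form(y_true, y_pred):
    # Sums y_true * log(y_pred) over the class axis, then averages over pixels.
    y_pred = np.clip(y_pred, EPS, 1 - EPS)
    return -np.mean(np.sum(y_true * np.log(y_pred), axis=-1))

# One "image" of 1x2 pixels with 3 classes (made-up numbers).
y_true = np.array([[[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]]])
y_pred = np.array([[[[0.7, 0.2, 0.1], [0.1, 0.8, 0.1]]]])

loss_elementwise = elementwise_form(y_true, y_pred)
loss_categorical = categorical_form(y_true, y_pred)
print(loss_elementwise, loss_categorical)
```

The two values differ, so the choice between them changes the scale of the loss and how the wrong-class probabilities are penalized; either can be minimized without reshaping.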




Answer 4:


It seems that you can now simply use a softmax activation on the last Conv2D layer, specify the categorical_crossentropy loss, and train on the image directly, without any reshaping tricks or custom loss function. I tried overfitting a dummy dataset and it works well. Give it a try!

from tensorflow import keras

inp = keras.Input(...)
# define your model here
out = keras.layers.Conv2D(classes, (1, 1), activation='softmax')(...)
model = keras.Model(inputs=[inp], outputs=[out], name='unet')
model.compile(loss='categorical_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])
# inputs and one-hot targets are both 4D tensors
model.fit(tensor4d, tensor4d)

You can also compile with sparse_categorical_crossentropy and then train with targets of shape (samples, height, width), where each pixel holds an integer class label: model.fit(tensor4d, tensor3d)
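To illustrate how the sparse and one-hot target formats relate, here is a small NumPy sketch. It is an illustration of the arithmetic only, not the Keras loss code, and the sizes and random data are assumptions. Picking the predicted probability of the labelled class should give the same per-pixel loss as the one-hot sum.

```python
import numpy as np

# Illustrative sizes and data only.
batch, height, width, num_classes = 2, 3, 4, 5
rng = np.random.default_rng(0)

# Random per-pixel probabilities, normalized over the class axis.
raw = rng.random(size=(batch, height, width, num_classes)) + 1e-3
y_pred = raw / raw.sum(axis=-1, keepdims=True)

# Sparse targets: one integer class label per pixel, shape (batch, height, width).
sparse_labels = rng.integers(0, num_classes, size=(batch, height, width))
# One-hot targets: the same information, shape (batch, height, width, num_classes).
one_hot_labels = np.eye(num_classes)[sparse_labels]

# One-hot form: sum y_true * log(y_pred) over the class axis.
loss_one_hot = -np.sum(one_hot_labels * np.log(y_pred), axis=-1)

# Sparse form: pick the predicted probability of the labelled class directly.
picked = np.take_along_axis(y_pred, sparse_labels[..., np.newaxis], axis=-1)
loss_sparse = -np.log(picked[..., 0])

print(np.allclose(loss_one_hot, loss_sparse))
```

The sparse format mainly saves memory, since the targets drop the class axis.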

The idea is that softmax and categorical_crossentropy are applied over the last axis (see the docs for keras.backend.softmax and keras.backend.categorical_crossentropy).

PS. I use keras from tensorflow.keras (tensorflow 2)

Update: I have trained on my real dataset and it is working as well.



Source: https://stackoverflow.com/questions/43033436/how-to-do-point-wise-categorical-crossentropy-loss-in-keras
