What does the implementation of keras.losses.sparse_categorical_crossentropy look like?


Question


I found that tf.keras.losses.sparse_categorical_crossentropy is an amazing function that lets me create a loss for a neural network with a very large number of output classes. Without it, training the model is impossible: tf.keras.losses.categorical_crossentropy gave an out-of-memory error because it requires converting each index into a one-hot vector of very large size.

However, I have trouble understanding how sparse_categorical_crossentropy avoids the big memory issue. I took a look at the TensorFlow code, but it is not easy to see what goes on under the hood.

So, could anyone give a high-level idea of how this is implemented? What does the implementation look like? Thank you!


Answer 1:


It does not do anything special: it just produces the one-hot encoded labels inside the loss for one batch of data at a time (not for all the data at once), when they are needed, and then discards them. So it's just a classic trade-off between memory and computation.
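For example, a minimal sketch of that idea might look like the following (this is not the actual Keras code; the helper name and the explicit num_classes argument are my own, and I assume y_true holds integer class indices while y_pred holds probabilities):

    import tensorflow as tf

    def sparse_ce_via_onehot(y_true, y_pred, num_classes):
        # y_true: integer class indices, shape (batch,)
        # y_pred: predicted probabilities, shape (batch, num_classes)
        # One-hot encode only the labels of this batch, use them once,
        # then let them be garbage-collected.
        one_hot = tf.one_hot(tf.cast(y_true, tf.int32), depth=num_classes)
        return -tf.reduce_sum(one_hot * tf.math.log(y_pred), axis=-1)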




Answer 2:


The formula for categorical crossentropy is the following:

    loss = -sum_i( y_true[i] * log(y_pred[i]) )

where y_true is the ground truth data and y_pred is your model's predictions.

The bigger the dimensions of y_true and y_pred, the more memory is needed to perform all of these operations.

But notice an interesting trick in this formula: only one of the entries in y_true is 1, and all the rest are zeros! This means only one term in the sum is non-zero: for a sample whose true class index is k, the whole sum collapses to -log(y_pred[k]).

What a sparse formula does is:

  • Avoid the need for a huge matrix for y_true, by using only class indices instead of one-hot encodings
  • Pick from y_pred only the column corresponding to each index, instead of performing calculations over the entire tensor.

So, the main idea of a sparse formula here is:

  • Gather columns from y_pred using the indices in y_true.
  • Calculate only the term -ln(y_pred_selected_columns), as in the sketch below.
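For illustration, here is a minimal sketch of that gather-and-log idea (again, not the actual TensorFlow implementation, which works on logits and handles numerical stability; the function name is made up and y_pred is assumed to contain probabilities):

    import tensorflow as tf

    def sparse_ce_via_gather(y_true, y_pred):
        # y_true: integer class indices, shape (batch,)
        # y_pred: predicted probabilities, shape (batch, num_classes)
        indices = tf.cast(y_true, tf.int32)
        # For each sample, pick only the predicted probability of its true class.
        picked = tf.gather(y_pred, indices, batch_dims=1)
        # Only this one term of the sum is non-zero, so compute just -ln(p_true).
        return -tf.math.log(picked)

Both sketches give the same per-sample losses for valid inputs, but the gather version never materializes a (batch, num_classes) one-hot tensor.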


Source: https://stackoverflow.com/questions/59577258/what-does-the-implementation-of-keras-losses-sparse-categorical-crossentropy-loo
