In this study, we systematically investigate the impact of class imbalance on the classification performance of convolutional neural networks and compare frequently used methods for addressing the issue. Class imbalance refers to a significant difference in the number of examples among classes in a training set. It is a common problem that has been studied comprehensively in classical machine learning, yet very limited systematic research is available in the context of deep learning.

Experiments

We define and parameterize two representative types of imbalance, i.e., step and linear. Using three benchmark datasets of increasing complexity, MNIST, CIFAR-10, and ImageNet, we investigate the effects of imbalance on classification and perform an extensive comparison of several methods for addressing the issue: oversampling, undersampling, two-phase training, and thresholding that compensates for prior class probabilities. Our main evaluation metric is the area under the receiver operating characteristic curve (ROC AUC) adapted to multi-class tasks, since the overall accuracy metric is associated with notable difficulties in the context of imbalanced data.
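To make the two imbalance types concrete, the minimal sketch below computes per-class training-set sizes for each type. The function names and the parameters rho (ratio between the largest and the smallest class) and mu (fraction of classes reduced to minority size, used only by step imbalance) are our own notation for this illustration, not fixed by the text above.

    # Step imbalance: a fraction mu of classes is reduced to
    # n_majority / rho examples; the remaining classes keep n_majority.
    def step_imbalance_counts(n_classes, n_majority, rho, mu):
        n_minority = int(n_majority / rho)
        n_minority_classes = round(mu * n_classes)
        return ([n_minority] * n_minority_classes
                + [n_majority] * (n_classes - n_minority_classes))

    # Linear imbalance: class sizes interpolate linearly between
    # n_majority / rho and n_majority.
    def linear_imbalance_counts(n_classes, n_majority, rho):
        n_minority = n_majority / rho
        step = (n_majority - n_minority) / (n_classes - 1)
        return [round(n_minority + i * step) for i in range(n_classes)]

    print(step_imbalance_counts(10, 5000, rho=10, mu=0.5))
    # [500, 500, 500, 500, 500, 5000, 5000, 5000, 5000, 5000]
    print(linear_imbalance_counts(10, 5000, rho=10))
    # [500, 1000, 1500, 2000, 2500, 3000, 3500, 4000, 4500, 5000]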

Figure 1: Examples from the datasets used in experiments (MNIST, CIFAR-10, ImageNet).

Dataset     Network     Code          Paper
MNIST       LeNet-5     BVLC          umontreal.ca
CIFAR-10    All-CNN     mateuszbuda   arXiv
ImageNet    ResNet-10   cvjena        arXiv
Table 1: Networks used in experiments.

Conclusions

Based on the results of our experiments, we conclude that:

  • the effect of class imbalance on classification performance is detrimental and increases with the extent of imbalance and the scale of the task;
  • oversampling emerged as the dominant method for addressing class imbalance in almost all analyzed scenarios;
  • oversampling should be applied to the level that eliminates the imbalance, whereas undersampling can perform better when the imbalance is only removed to some extent;
  • thresholding should be applied to compensate for prior class probabilities when the overall number of properly classified cases is of interest (see the sketch after this list);
  • as opposed to some classical machine learning models, oversampling does not necessarily cause overfitting of convolutional neural networks.
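One simple form of the thresholding mentioned above is to divide each predicted posterior by the corresponding class's training-set frequency and renormalize before taking the argmax. The sketch below assumes softmax outputs and uses NumPy; the function name and the example numbers are illustrative only.

    import numpy as np

    def compensate_for_priors(probs, class_counts):
        # probs: (n_samples, n_classes) softmax outputs.
        # class_counts: number of training examples per class.
        priors = np.asarray(class_counts, dtype=float)
        priors /= priors.sum()
        # Divide each posterior by its class prior, then renormalize.
        adjusted = probs / priors
        return adjusted / adjusted.sum(axis=1, keepdims=True)

    probs = np.array([[0.6, 0.4]])   # raw softmax: (majority, minority)
    counts = [5000, 500]             # imbalanced training set
    adjusted = compensate_for_priors(probs, counts)
    print(adjusted)                  # approximately [[0.13, 0.87]]
    print(adjusted.argmax(axis=1))   # [1]: the minority class wins

Note that dividing a class's scores by a constant does not change their ordering across samples, so per-class ROC AUC is unaffected; the adjustment only shifts the decision boundary between classes, which is why it helps when overall accuracy is the quantity of interest.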