Loss explosion while training a custom model 💨
Detailed Intuition with Explanation
One of the great articles that this post builds on is Exploding gradients in neural networks.
So, if you’ve already started training your network/model, you must know about loss. Put simply, it’s a numerical way of saying that the model’s predictions deviated from the ground truth by some amount ‘x’; that amount is the loss. There are many loss functions to pick from, so choose the one that suits your data.
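For a concrete feel, here is a minimal sketch of two common loss functions (plain Python, not tied to any framework; the numbers are made up):

```python
import math

# Mean squared error: average squared deviation of predictions from targets.
def mse(y_true, y_pred):
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)

# Cross-entropy for a single correct class: penalises low predicted probability.
def cross_entropy(prob_of_true_class):
    return -math.log(prob_of_true_class)

print(mse([1.0, 2.0], [1.5, 1.0]))   # 0.625: the predictions varied by that much
print(cross_entropy(0.9))            # ~0.11: confident and correct -> low loss
print(cross_entropy(0.01))           # ~4.61: confident and wrong   -> high loss
```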
So, loss explosion can happen due to one of two events:
1. Number-of-classes conflict between the .config file and the labelmap.pbtxt file
The model ends up making predictions for classes that are simply not present in the training data, i.e., predictions that are very far from the actual data. This happens when your label map (referenced from the training .config file) defines fewer classes than the number declared in the .config file itself. The model is then trained to identify, say, 99 classes from a dataset that contains only 1 object/class. It will produce predictions for classes [1–99], but on every backpropagation step it receives a high loss for classes [2–99] and a feasible loss only for class [1]. The loss accumulated across all 98 spurious classes keeps compounding, and hence it explodes.
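To make that concrete, here is a minimal sketch (plain NumPy, with made-up numbers; `avg_loss` is a hypothetical helper, not part of any API). With softmax cross-entropy, the probability mass spread over 98 classes that never appear in the ground truth keeps the loss on the single real class permanently high, and that inflated term is what the optimiser then has to fight on every step:

```python
import numpy as np

rng = np.random.default_rng(0)

def avg_loss(num_classes, steps=1000):
    """Average softmax cross-entropy when the ground truth is always class 0."""
    total = 0.0
    for _ in range(steps):
        logits = rng.normal(size=num_classes)           # an untrained head's scores
        probs = np.exp(logits) / np.exp(logits).sum()   # softmax
        total += -np.log(probs[0])                      # class 0 is the only real one
    return total / steps

print(avg_loss(num_classes=1))    # 0.0: one class, nothing to get wrong
print(avg_loss(num_classes=99))   # ~5: mass wasted on 98 impossible classes
```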
Graphically, the loss curve climbs steeply upward, instead of steadily descending towards a minimum as it would in a healthy run.
Here, the gradient term refers to how much the weight values should be modified, and in which direction, to match the ground-truth class/value.
2. The Optimiser gets confused
This happens when the training learning rate is too high for the dataset.
Say your data is simple, like the MNIST dataset. We’ll require a very low learning rate so that the optimiser can come up with gradient steps that bring the predictions towards the global minimum. If a larger learning rate is provided, the next update pushes the weights well past the values they should settle at, and the model learns false information. On the next iteration it crosses the minimum again, out to the other side, and each swing gets wider.
Graphically, we are following the yellow line: each step bounces across the valley to the other side instead of settling at the bottom.
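Here is a minimal sketch of that overshoot (plain-Python gradient descent on a toy quadratic loss, not the actual object-detection training loop; the learning-rate values are made up):

```python
# Toy loss L(w) = w**2 with its global minimum at w = 0.
def train(learning_rate, steps=10, w=1.0):
    for _ in range(steps):
        grad = 2 * w                    # dL/dw: size and direction of the update
        w = w - learning_rate * grad    # standard gradient-descent step
    return w

print(train(learning_rate=0.01))   # ~0.82: w creeps steadily towards 0
print(train(learning_rate=1.5))    # 1024.0: every step overshoots the minimum
                                   # to the other side and the swings explode
```

With learning_rate = 1.5, each step multiplies the distance from the minimum by 2, so the weight diverges exponentially; that is the yellow zig-zag above.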
Here is another useful, in-detail graphical intuition, and here is a great machine-learning course where you can also learn about debugging your model.
Implementation:
Before:
learning_rate_base: 0.800000011920929
warmup_learning_rate: 0.13333000242710114
After:
learning_rate_base: .008
warmup_learning_rate: 0.0013333
Here is my reply in the official issue.
That’s all. See you in the next post 🤩