Photo by Markus Winkler on Unsplash

Machine Learning: Regularization Techniques

How to prevent a neural network from getting so good it’s bad.

Justin Masayda
3 min read · Jan 24, 2022


A sufficiently complex neural network can result in ~100% accuracy on the data it was trained with, but significant error on unfamiliar/validation data. When this occurs, the network is overfitting the training data. This means that it makes predictions that are too strongly attached to features it learned in training, but which don’t necessarily correlate with the expected results.

One way to temper overfitting is by using a process called regularization. Regularization generally works by penalizing a neural network for complexity. The less complex yet accurate the model, the lower the cost.

I’ll explain some common regularization techniques. If you’re just here for the intuition, feel free to skip the “technical explanation” sections.

L1 and L2 Regularization

Intuition

In order to improve on its task, a neural network needs to be “told” whether it’s making correct predictions or not, and if not, how far off it is. The degree to which it is wrong is called cost, and is calculated with a cost function.

L1 and L2 regularization work by adding an extra term to the cost function. That extra cost is proportional to the size of the network’s weights, so as the weights grow large, the cost grows with them, and the training process is nudged toward keeping the weights small, close to zero.

Technical Explanation

L1 regularization uses the L1 norm as the penalty, which is defined as the sum of the absolute values of the weights. If W is a vector of n weights, the L1 norm of W is defined as:

||W||₁ = |w1| + |w2| + … +|wn|
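
As a quick illustration, here’s a minimal NumPy sketch of that sum. The weight values are made up purely for the example:

import numpy as np

W = np.array([0.5, -1.2, 3.0])   # made-up weight vector
l1_norm = np.sum(np.abs(W))      # |0.5| + |-1.2| + |3.0| = 4.7
print(l1_norm)                   # same result as np.linalg.norm(W, 1)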

L2 regularization uses the L2 norm, which is the square root of the sum of the squares of the weights:

||W||₂ = √(w1² + w2² + … + wn²)

In practice, it isn’t necessary to take the square root; it only adds computation with no benefit to the cost function, so the squared L2 norm is used instead:

||W||₂² = w1² + w2² + … +wn² = W · Wᵀ
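
Sticking with the same made-up weights, a short NumPy sketch confirms that the sum of squares and the dot product in the formula give the same number:

import numpy as np

W = np.array([0.5, -1.2, 3.0])               # same made-up weight vector as above
l2_squared = np.sum(W ** 2)                  # 0.25 + 1.44 + 9.0 = 10.69
print(np.isclose(l2_squared, np.dot(W, W)))  # True — matches the W · Wᵀ form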

Both methods cause the network to incur a penalty that increases the farther the weights are from 0 (whether positive or negative).

So the regularized cost function adds a fraction of the L1 or L2 norm, scaled by the regularization parameter (a hyperparameter usually denoted λ) that determines how much regularization to apply.

The formula for the cost function using L2 regularization is:

Cost + (λ / 2m) · ||W||₂²

where m is the number of training examples.

Why divide by 2m? Dividing by m averages the penalty over the training set, so λ behaves consistently regardless of dataset size, and the extra factor of 2 simply cancels when the squared term is differentiated during backpropagation.
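
Putting the pieces together, here’s a hedged NumPy sketch of the regularized cost. The base cost, λ, m, and the weights are all placeholder values, not numbers from a real model:

import numpy as np

W = np.array([0.5, -1.2, 3.0])   # made-up weights
m = 100                          # number of training examples (placeholder)
lam = 0.01                       # regularization parameter λ (placeholder)
base_cost = 0.35                 # unregularized cost, e.g. cross-entropy (placeholder)

l2_penalty = (lam / (2 * m)) * np.sum(W ** 2)   # (λ / 2m) · ||W||₂²
regularized_cost = base_cost + l2_penalty       # the L1 version would swap in np.sum(np.abs(W))
print(regularized_cost)

In a framework like Keras, the same effect is usually obtained by attaching an L1 or L2 regularizer to a layer’s weights rather than editing the cost function by hand.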

Dropout

Intuition

Another way to reduce complexity is to reduce the number of nodes in the network, which is effectively what dropout accomplishes. Dropout is the technique of randomly disabling nodes throughout training, which makes the network behave more like an ensemble of smaller networks than a single overly complex one. Dropout is tuned by a hyperparameter that defines the probability of a neuron being “dropped.”

Technical Explanation
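
One common way to implement this is “inverted dropout,” sketched below in NumPy. The activation matrix and the drop probability here are made-up placeholders:

import numpy as np

drop_prob = 0.2                   # hyperparameter: chance a given node is dropped
A = np.random.rand(4, 5)          # stand-in for a hidden layer's activations

# Training: zero out roughly 20% of nodes, then scale the survivors up so the
# expected magnitude of the layer's output stays the same.
mask = (np.random.rand(*A.shape) > drop_prob).astype(A.dtype)
A_train = A * mask / (1 - drop_prob)

# Inference: dropout is turned off and the layer is used as-is.
A_test = A

Libraries like Keras expose this as a ready-made Dropout layer whose main argument is the drop rate.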

Data Augmentation

Overfitting can also be reduced by training on a larger variety of data. This is the role of data augmentation: applying various relevant transformations to the input data so that the network has a wider variety of material to learn from. It can be much quicker and cheaper to modify existing data than to acquire new data from scratch.

Technical Explanation
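
For image data, even a handful of simple, label-preserving transformations can multiply the effective size of the training set. Here’s a minimal NumPy sketch; the image array is a random placeholder:

import numpy as np

image = np.random.rand(32, 32, 3)              # stand-in for one training image (H, W, C)

flipped = image[:, ::-1, :]                    # horizontal flip
shifted = np.roll(image, shift=2, axis=1)      # shift two pixels along the width (with wraparound)
noisy = np.clip(image + np.random.normal(0, 0.02, image.shape), 0.0, 1.0)  # mild noise

augmented = np.stack([image, flipped, shifted, noisy])   # four training variants from one example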

Early Stopping

Finally, there’s one more common technique to reduce overfitting: early stopping. Eventually, a network will reach a point where it isn’t getting much more accurate. If the network is trained beyond this point, it may only increase accuracy on the training data and begin decreasing in general accuracy. For this reason, it’s a good idea to stop training once a certain amount of stagnation has occurred. That’s the idea behind early stopping.

Technical Explanation
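
A bare-bones “patience” loop captures the idea: track the best validation loss seen so far and stop once it hasn’t improved for a set number of epochs. The validation loss below is simulated just to make the sketch runnable; in a real training loop it would come from evaluating the model on held-out data:

import random

max_epochs = 100
patience = 5                         # epochs to wait for an improvement before stopping
best_val_loss = float("inf")
epochs_without_improvement = 0

for epoch in range(max_epochs):
    # Placeholder for real training + validation: a noisy loss curve
    # that gradually stops improving.
    val_loss = 1.0 / (epoch + 1) + random.uniform(0, 0.05)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        epochs_without_improvement = 0   # in practice, also checkpoint the weights here
    else:
        epochs_without_improvement += 1
        if epochs_without_improvement >= patience:
            print(f"Stopping early at epoch {epoch}")
            break

Most libraries ship this behavior as a ready-made utility, for example Keras’s EarlyStopping callback with a patience argument.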

Now you know a few common regularization techniques. When these techniques are used, neural networks generalize significantly better to data they have never seen.

