Decode some Deep Learning phrases

  • Gradient Clipping
  • Gradient Checkpointing
  • Gradient Rollback
  • Gradient Accumulation

Gradient Clipping
Gradient clipping is a technique to prevent exploding gradients in very deep networks, usually recurrent neural networks. A neural network is a learning algorithm (also called a neural net) that uses a network of functions to understand and translate input data into a specific output.
With gradient clipping, a pre-determined gradient threshold is introduced, and gradient norms that exceed this threshold are scaled down to match it. This prevents any gradient from having a norm greater than the threshold, and thus the gradients are clipped.
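
A minimal sketch of this in PyTorch (the toy model, random data, and threshold of 1.0 are placeholders for illustration); the clipping call sits between the backward pass and the optimizer step:

    import torch
    import torch.nn as nn

    model = nn.Linear(10, 1)  # toy model, standing in for a deep/recurrent net
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    criterion = nn.MSELoss()

    x, y = torch.randn(16, 10), torch.randn(16, 1)

    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()

    # Scale gradients so their total norm does not exceed the threshold (here 1.0)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

    optimizer.step()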

Reference: https://deepai.org/machine-learning-glossary-and-terms/gradient-clipping#:~:text=Gradient%20clipping%20is%20a%20technique,input%20into%20a%20specific%20output.

Gradient Checkpointing
In a nutshell, gradient checkpointing works by recomputing the intermediate values of a deep neural net (which would ordinarily be stored at forward time) at backward time. This trades compute (the extra time spent recomputing these values during the backward pass) for memory (the space cost of storing them ahead of time).

Gradient checkpointing works by omitting some of the activation values from the computational graph. This reduces the memory used by the computational graph, reducing memory pressure overall (and allowing larger batch sizes in the process).

However, the reason that the activations are stored in the first place is that they are needed when calculating the gradient during backpropagation. Omitting them from the computational graph forces PyTorch to recalculate these values wherever they appear, slowing down computation overall.

Thus, gradient checkpointing is an example of one of the classic tradeoffs in computer science— that which exists between memory and compute.
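
A minimal sketch of the idea using PyTorch's torch.utils.checkpoint (the layer sizes and the split into two blocks are arbitrary choices for illustration):

    import torch
    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint

    # Two sub-networks; the first is checkpointed, so its activations are
    # recomputed during the backward pass instead of being stored.
    block1 = nn.Sequential(nn.Linear(512, 512), nn.ReLU(),
                           nn.Linear(512, 512), nn.ReLU())
    block2 = nn.Linear(512, 10)

    x = torch.randn(32, 512, requires_grad=True)

    h = checkpoint(block1, x)   # activations inside block1 are not kept
    out = block2(h).sum()
    out.backward()              # block1 is re-run here to rebuild the needed activations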

Reference: https://spell.ml/blog/gradient-checkpointing-pytorch-YGypLBAAACEAefHs

Gradient Rollback
Gradient Rollback reveals the training data that has the greatest influence on a prediction. Users can ascertain how plausible a prediction is by viewing its explanation (the training instances with the highest influence).
Reference: https://www.neclab.eu/blog/gradient-rollback#:~:text=Gradient%20Rollback%20reveals%20the%20training,instances%20with%20the%20highest%20influence).

Gradient Accumulation
Simply speaking, gradient accumulation means that we will use a small batch size but save the gradients and update network weights once every couple of batches.

When training a neural network, we usually divide our data into mini-batches and go through them one by one. The network predicts batch labels, which are used to compute the loss with respect to the actual targets. Next, we perform a backward pass to compute gradients and update the model weights by stepping against those gradients.

Gradient accumulation modifies the last step of the training process. Instead of updating the network weights on every batch, we can save gradient values, proceed to the next batch and add up the new gradients. The weight update is then done only after several batches have been processed by the model.

Gradient accumulation helps to imitate a larger batch size. Imagine you want to use 32 images in one batch, but your hardware runs out of memory once you go beyond 8. In that case, you can use batches of 8 images and update weights once every 4 batches. If you accumulate gradients from every batch in between, the results will be (almost) the same and you will be able to perform training on a less expensive machine!
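
A minimal sketch of this in PyTorch (the toy model, random data, and accumulation_steps value of 4 are placeholders); the key change is that optimizer.step() and optimizer.zero_grad() run only once every accumulation_steps batches:

    import torch
    import torch.nn as nn

    model = nn.Linear(10, 1)                      # toy model for illustration
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    criterion = nn.MSELoss()

    accumulation_steps = 4                        # 4 batches of 8 imitate a batch of 32
    data = [(torch.randn(8, 10), torch.randn(8, 1)) for _ in range(16)]

    optimizer.zero_grad()
    for i, (x, y) in enumerate(data):
        loss = criterion(model(x), y)
        # Divide by the number of accumulation steps so the summed gradients
        # match the average gradient of one large batch.
        (loss / accumulation_steps).backward()    # gradients accumulate in .grad

        if (i + 1) % accumulation_steps == 0:
            optimizer.step()                      # update weights once every 4 batches
            optimizer.zero_grad()                 # reset the accumulated gradients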
Reference: https://kozodoi.me/python/deep%20learning/pytorch/tutorial/2021/02/19/gradient-accumulation.html#:~:text=Simply%20speaking%2C%20gradient%20accumulation%20means,might%20find%20this%20tutorial%20useful.
