Second-order optimization methods
One line of thinking that's gained popularity in recent times is using second-order information from the loss to help prune or quantize the model. An example is Hessian-based mixed-precision quantization (the paper: HAWQ). The idea here was to use the Hessian to tell us which layers contributed the most to maintaining low loss, and these layers in turn were deemed more 'valuable' and kept in higher precision. The reason for the Hessian was that it could tell us about the curvature of the loss surface at the optimum along different axes - if there was a steep upward curve with respect to a certain layer's weights, any perturbation to those weights could cause a large jump in the loss. The computation was done efficiently by using only the top eigenvalues of the Hessian rather than computing the entire Hessian for all parameters (which would take O(n^2) memory for n parameters).
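The trick that makes this tractable is that a Hessian-vector product only needs two backward passes (Pearlmutter's trick), so the top eigenvalue can be estimated by power iteration without ever materializing the Hessian. Below is a minimal PyTorch sketch of that idea; the function name and structure are mine, not the HAWQ or PyHessian API:

```python
import torch

def top_hessian_eigenvalue(loss, params, iters=20):
    """Estimate the largest Hessian eigenvalue by power iteration.

    Uses Hessian-vector products (differentiate grad . v), so the full
    n x n Hessian is never formed. `loss` must have been computed from
    `params` with the autograd graph still alive.
    """
    grads = torch.autograd.grad(loss, params, create_graph=True)
    v = [torch.randn_like(p) for p in params]
    eig = 0.0
    for _ in range(iters):
        # Normalize the current direction across all parameter tensors.
        norm = torch.sqrt(sum((u * u).sum() for u in v))
        v = [u / norm for u in v]
        # Hessian-vector product via a second backward pass.
        gv = sum((g * u).sum() for g, u in zip(grads, v))
        hv = torch.autograd.grad(gv, params, retain_graph=True)
        # Rayleigh quotient gives the current eigenvalue estimate.
        eig = sum((h * u).sum() for h, u in zip(hv, v)).item()
        v = [h.detach() for h in hv]
    return eig
```

Running this per layer (passing only that layer's parameters) gives the kind of per-layer sensitivity ranking used to decide which layers stay in higher precision.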
This approach made me consider the idea of using second-order optimization methods in training. One idea that came to mind was as follows: take the Hessian at a certain point in parameter space to estimate convexity. Let's say such a point exists (i.e., the Hessian is positive semi-definite). Then use a quadratic approximation to find an optimum to jump to, and calculate the loss again at the new point (a toy sketch of this step follows the list below). There are three main questions here:
- How efficiently can we calculate the Hessian?
- Are locally convex points easy to find?
- Is this method actually more efficient?
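To make the proposed step concrete, here is a toy NumPy sketch under the assumption that we can afford the exact Hessian (only true for tiny problems): check the eigenvalues for local convexity, and if that passes, jump to the minimizer of the local quadratic model by taking the Newton step d = -H^-1 g.

```python
import numpy as np

def quadratic_jump(grad_fn, hess_fn, theta, damping=1e-8):
    """One 'jump' of the idea above: if the Hessian at theta is positive
    semi-definite, move to the minimizer of the local quadratic model."""
    g, H = grad_fn(theta), hess_fn(theta)
    if np.linalg.eigvalsh(H).min() < 0:
        return theta, False          # not locally convex here
    d = np.linalg.solve(H + damping * np.eye(len(theta)), -g)
    return theta + d, True

# On a purely quadratic loss, the jump lands on the optimum in one step.
A = np.array([[3.0, 1.0], [1.0, 2.0]])   # SPD toy Hessian
b = np.array([1.0, -1.0])
loss = lambda t: 0.5 * t @ A @ t - b @ t
theta, ok = quadratic_jump(lambda t: A @ t - b, lambda t: A, np.zeros(2))
print(ok, loss(np.zeros(2)), loss(theta))   # loss drops to its minimum
```

For a network-sized parameter vector, the exact Hessian and its eigendecomposition in this sketch are exactly the parts that become unaffordable, which is what the references below try to approximate or avoid.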
I did some literature review on this and came up with the following:
For computing the Hessian:
- Review of the complexity of Hessian computation: https://www.microsoft.com/en-us/research/wp-content/uploads/2016/02/bishop-hessian-nc-92.pdf
- Iterative method to approximate the Hessian for optimization (the BFGS algorithm; a quick SciPy usage sketch follows this list): https://en.wikipedia.org/wiki/Broyden–Fletcher–Goldfarb–Shanno_algorithm
- PyHessian: Library and paper from the Kurt Keutzer group at UC Berkeley in 2020 - allows for efficient Hessian-vector product calculation, since the Hessian itself is too heavy to compute (O(n^2) memory, where n is the number of parameters). Has some nice approaches to visualize the loss landscape as well. Paper link. Github
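As a quick illustration of the quasi-Newton route from the BFGS link above: BFGS builds an approximation to the inverse Hessian from successive gradient differences, so you only ever supply the loss and its gradient. A toy usage sketch with SciPy (the quadratic is just an assumed stand-in for a real loss):

```python
import numpy as np
from scipy.optimize import minimize

# Same kind of toy quadratic as in the earlier sketch; BFGS never forms
# the Hessian, it updates an inverse-Hessian estimate as it iterates.
A = np.array([[3.0, 1.0], [1.0, 2.0]])
b = np.array([1.0, -1.0])
res = minimize(
    fun=lambda t: 0.5 * t @ A @ t - b @ t,
    x0=np.zeros(2),
    jac=lambda t: A @ t - b,
    method="BFGS",
)
print(res.x, np.allclose(A @ res.x, b, atol=1e-4))  # converges to A^-1 b
```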
Iterative methods using second-order information:
- Paper on retraining an MLP (seems to use a similar idea) from 1991 by Bishop: https://www.microsoft.com/en-us/research/wp-content/uploads/1991/01/Bishop-Fast-Procedure-Retaining-multilayer-perception-IJNS91.pdf
- Newton-Raphson method for iterative optimization: https://en.wikipedia.org/wiki/Newton's_method
- Sophia: an optimizer that uses lightweight second-order statistics (an estimated Hessian diagonal) to precondition updates. Found this while studying the PyHessian paper (a rough sketch of this style of update follows the list). Paper link
- Shampoo algorithm: maintains per-layer matrix-valued preconditioners rather than purely diagonal statistics
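The common thread in Sophia-style methods is a cheap diagonal curvature estimate used to rescale the gradient, with clipping so near-zero-curvature coordinates don't blow up the step. The PyTorch sketch below shows that flavor (a Hutchinson-style diagonal estimate plus a clipped preconditioned step); it is my own rough approximation of the idea, not the exact Sophia or Shampoo update rule.

```python
import torch

def hutchinson_diag_hessian(loss, params):
    """One-sample Hutchinson estimate of the Hessian diagonal: z * (H z)
    with Rademacher z. In practice this would be averaged across steps."""
    grads = torch.autograd.grad(loss, params, create_graph=True)
    zs = [torch.randint_like(p, 0, 2) * 2.0 - 1.0 for p in params]
    gz = sum((g * z).sum() for g, z in zip(grads, zs))
    hz = torch.autograd.grad(gz, params)
    return [z * h for z, h in zip(zs, hz)]

@torch.no_grad()
def clipped_preconditioned_step(params, grads, hdiag, lr=1e-3, eps=1e-8, clip=1.0):
    # Divide each gradient coordinate by its estimated curvature, then clip
    # the resulting ratio so tiny curvature cannot produce a huge update.
    for p, g, h in zip(params, grads, hdiag):
        step = (g / (h.abs() + eps)).clamp(-clip, clip)
        p.add_(step, alpha=-lr)
```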