Second-order optimization methods

One line of thinking that has gained popularity in recent times is using second-order information from the loss to help prune or quantize the model. An example is Hessian-based mixed-precision quantization (the HAWQ paper). The idea there is to use the Hessian to tell us which layers contribute the most to maintaining a low loss; these layers are in turn deemed more 'valuable' and kept in higher precision. The reason for using the Hessian is that it tells us about the curvature of the loss surface at the optimum along different axes: if there is a steep upward curve with respect to a certain layer's weights, any perturbation to those weights can cause a large jump in the loss. The computation is done efficiently by using only the top eigenvalues of the Hessian rather than computing the entire Hessian over all parameters (O(n²) memory, O(n³) compute). The Hessian is also used in the popular GPTQ method, which in turn takes its ideas from the Optimal Brain Surgeon (OBS) method.
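To make the "top eigenvalues without the full Hessian" point concrete, here is a minimal PyTorch sketch (not the HAWQ implementation; the function names and the toy quadratic loss are my own) that estimates the top Hessian eigenvalue using only Hessian-vector products via double backprop plus power iteration:

```python
import torch

def hvp(loss_fn, params, vec):
    """Hessian-vector product via double backprop; never materializes the full Hessian."""
    loss = loss_fn(params)
    (grad,) = torch.autograd.grad(loss, params, create_graph=True)
    (hv,) = torch.autograd.grad(grad @ vec, params)
    return hv

def top_hessian_eigenvalue(loss_fn, params, iters=50):
    """Power iteration on the Hessian, using only Hessian-vector products."""
    v = torch.randn_like(params)
    v /= v.norm()
    eig = 0.0
    for _ in range(iters):
        hv = hvp(loss_fn, params, v)
        eig = torch.dot(v, hv).item()      # Rayleigh quotient estimate
        v = hv / (hv.norm() + 1e-12)
    return eig

# Toy quadratic loss with known curvature: the largest Hessian eigenvalue is 4.0.
A = torch.diag(torch.tensor([4.0, 1.0, 0.5]))
params = torch.randn(3, requires_grad=True)
loss_fn = lambda p: 0.5 * (p @ A @ p)
print(top_hessian_eigenvalue(loss_fn, params))   # ~4.0
```

Each power-iteration step costs roughly two backward passes, so memory stays O(n) instead of the O(n²) needed to store the full Hessian.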

This approach made me consider the idea of using second-order optimization methods in training. One idea that came to mind was as follows: take the Hessian at a certain point in the parameter-loss space to estimate local convexity. Suppose such a point exists (i.e., the Hessian is positive semi-definite). Then use a quadratic approximation of the loss to find a minimizer to jump to, and calculate the loss again at the new point (a sketch of this step follows the questions below). There are three main questions here:

  1. How efficiently can we calculate the Hessian?
  2. Are locally convex points easy to find?
  3. Is this method actually more efficient?
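As a concrete version of the jump described above, here is a toy sketch (my own naming; it forms the full Hessian, which is only feasible at toy scale) that takes one quadratic-model step theta_new = theta - H^{-1} g when the curvature check passes:

```python
import torch
from torch.autograd.functional import hessian

def newton_jump(loss_fn, theta):
    """One quadratic-model jump: theta - H^{-1} g, taken only if the Hessian is positive definite."""
    loss = loss_fn(theta)
    (g,) = torch.autograd.grad(loss, theta)
    H = hessian(loss_fn, theta)                # full Hessian: only feasible for small parameter counts
    if torch.linalg.eigvalsh(H).min() <= 0:    # slightly stricter than PSD so the solve is well-posed
        return theta, False
    step = torch.linalg.solve(H, g)            # computes H^{-1} g without forming H^{-1}
    theta_new = (theta - step).detach().requires_grad_(True)
    return theta_new, True

# Toy convex quadratic: the quadratic model is exact, so one jump lands on the minimizer A^{-1} b.
A = torch.tensor([[3.0, 0.5], [0.5, 2.0]])
b = torch.tensor([1.0, -2.0])
loss_fn = lambda t: 0.5 * (t @ A @ t) - b @ t

theta = torch.randn(2, requires_grad=True)
theta, jumped = newton_jump(loss_fn, theta)
print(jumped, loss_fn(theta).item())           # True, loss at the analytic optimum
```

For a convex quadratic the jump is exact; for a real loss it is only as good as the local quadratic approximation, which is exactly what the three questions above are probing.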

I did some literature review on this and came up with the following:

For computing the Hessian:

Iterative methods using second-order information: