1. Introduction to Quantization
Background
I was first introduced to quantization after grad school, while working at an ML chip accelerator company. My work there centered on a challenge that many modern ML models face - they're very expensive to run. On a local device like a car or a phone with limited compute, running a large computer vision or language model can be very slow, consume too much energy, or simply not fit in the available memory. Making API calls to a cloud-based model has its own problems - high cost, latency, and a lack of data privacy and model customization.
Generally, we were given a pre-trained model from a customer company or client, with the goal of making it run on our custom chip. The chip could run models at high efficiency[1], thanks to some smart engineering of its primary processor, instruction scheduler, and memory transfers. However, to achieve more efficient processing, we needed to change the model a little - instead of keeping the parameters and activations in high precision (typically 16- or 32-bit floating point), they had to be compressed to 8-bit integers. This would accelerate memory transfers and matrix multiplications[2], letting us hit higher efficiency numbers and stay competitive.
So far so good, right? Here's the kicker - when we reduced the bit precision of our model values, the models would often lose a huge amount of performance on the tasks they were trained for. Sometimes the accuracy[3] drop would be a few percentage points, but it could be as large as 50% or more, indicating the model had completely lost the knowledge it learned in training. Worse still, although some models were known to be more susceptible than others, these variations were highly unpredictable.
Luckily, we weren't the only ones facing this issue. Everyone from Nvidia to Google had run into this wall, coming up with new methods and training strategies to help. Researchers also noticed some underlying issues that could be mitigated fairly cheaply given a pre-trained model in full precision. The umbrella term for these approaches came to be 'post-training quantization' (PTQ). I got to test a lot of these methods, and as I did I grew increasingly curious about the general problem of quantization. The problem seems to give us insights about the loss landscape, model training dynamics, and the central problem in ML - how we can learn computationally efficient representations of a given dataset and task[4].
To be clear, PTQ is not the only approach to quantization. Exciting new techniques are emerging that focus on low-precision model training (FP8 training), along with intermediate solutions like quantization-aware training (QAT) and quantized fine-tuning (QLoRA). I hope to use this deep dive to explore current state-of-the-art approaches, as well as explanations for what seems to be happening in the quantized regime. I'll also be tackling problems that help map out the lay of the land, even if they're not directly related.
The general problem of model compression and optimization has led to an explosion of fascinating work outside of quantization. This ranges from optimizing how activations are stored and moved through memory (eg, flash-attention), to building faster kernels that perform the linear algebra of ML more efficiently, to distilling and pruning models, to developing entirely new models based on classical signal processing theory (SSMs). Broadly, these fall into hardware and software optimizations (I clump model changes into the software category). We'll be exploring many of these ideas, with a bias towards model-centered techniques. Part of the reason for this is practical - working at the hardware level is fairly difficult without a good amount of prior setup, hardware navigation skill, and reliable hardware to tinker with. The other reason is that I'm simply more interested in the work on the ML side right now, although I might explore hardware in the future!
Basics of Quantization
Here we provide a brief definition of quantization. Simply put, it's the process of discretizing a continuous (aka higher precision) spectrum of values using a chosen scheme. This process bins values into designated groupings, reducing the number of bits (and therefore bytes) needed to store the data. A simple quantization scheme is roundToNearestInteger(float val), which rounds any real-valued number to the closest integer. Under this scheme, the values 3.05, 3.003, and 3.49 all round to 3. As you can see, we lose precise information about our values in exchange for less memory consumed. Visualizing Quantization is a great resource for seeing how quantization works.
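To make this concrete, here's a minimal NumPy sketch of that round-to-nearest scheme (the array values are just the examples above):

```python
import numpy as np

# Round-to-nearest quantization of the example values above.
vals = np.array([3.05, 3.003, 3.49])
quantized = np.rint(vals).astype(np.int8)   # every value lands in the bin "3"
print(quantized)                            # [3 3 3]
```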
The first topic we'll work on will rely on fixed point (aka integer) quantization, so we detail this approach below.
Fixed-Point Quantization
The process for fixed-point/integer quantization is as follows: Suppose we have a group of values (eg, a model weight tensor). If we wish to compress these values to b-bit integers, we first find the range of the group, ie its minimum and maximum. From this range we compute two quantization parameters: a scale, which is the real-valued width of one integer bin, and a zero-point[5], the integer that the real value 0 maps to. Each value is then divided by the scale, shifted by the zero-point, and rounded to the nearest integer in the b-bit range. To approximately recover a value, we reverse the process: subtract the zero-point and multiply by the scale.
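Here's a minimal NumPy sketch of this scheme (per-tensor, unsigned 8-bit, range taken straight from the min and max, no clipping; `quantize`/`dequantize` are illustrative names, not any particular library's API):

```python
import numpy as np

def quantize(x, num_bits=8):
    """Asymmetric fixed-point quantization of one group of values."""
    qmin, qmax = 0, 2 ** num_bits - 1
    scale = (x.max() - x.min()) / (qmax - qmin)        # real-valued width of one integer bin
    zero_point = int(np.rint(qmin - x.min() / scale))  # integer that the real value 0 maps to
    q = np.clip(np.rint(x / scale) + zero_point, qmin, qmax).astype(np.uint8)
    return q, scale, zero_point

def dequantize(q, scale, zero_point):
    """Map the integers back to approximate real values."""
    return scale * (q.astype(np.float32) - zero_point)

w = np.random.randn(4, 4).astype(np.float32)
w_q, s, z = quantize(w)
print(np.abs(w - dequantize(w_q, s, z)).max())   # error is at most about half a bin (scale / 2)
```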
Without going into too much detail, the payoff is the matmul itself: once the weights and activations are stored this way, the bulk of a matrix multiplication can be carried out with cheap integer multiply-accumulate operations, with a single rescale back to real values at the end.
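Continuing the sketch, and reusing `quantize` from the previous snippet, the core of a quantized matmul might look like this (shapes and names are illustrative):

```python
# Quantize an activation x and a weight w, then do the heavy arithmetic on integers.
x = np.random.randn(2, 16).astype(np.float32)
w = np.random.randn(16, 4).astype(np.float32)

x_q, s_x, z_x = quantize(x)
w_q, s_w, z_w = quantize(w)

acc = (x_q.astype(np.int32) - z_x) @ (w_q.astype(np.int32) - z_w)   # integer multiply-accumulates (MACs)
y_approx = (s_x * s_w) * acc                                        # a single float rescale at the end

print(np.abs(x @ w - y_approx).max())   # small relative to the outputs themselves
```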
We don't have to stick to one tensor as our group. If we wanted, we could take all of our model's weights and activations[6], calculate a single bin size, and quantize every value into those bins. However, numerical values across the model can vary wildly, spanning multiple orders of magnitude, so this approach results in a huge quantization error. On the other hand, in theory we could 'quantize' each and every value individually, with its own quantization parameters - but this buys us nothing, since we'd be storing full-precision parameters for every single value and saving no memory at all.
Clearly, we need to find the right level of granularity for our groups. Through experiment and a study of hardware operator flow, engineers have found that the optimal granularity differs between weights and activations. For activations, the best balance is to treat each intermediate activation tensor as one group and quantize it accordingly. For weights, however, we can get away with a more granular approach and perform per-channel/neuron quantization. This is best explained visually, as shown in Figure 1.
Figure 1. A cross section of a convolutional neural network. Each of the blocks (B1, B2, etc.) is an intermediate activation, with one block corresponding to one tensor. Each strip within a block is one channel (eg, there are 64 channels in B1). Although not shown here, the weights that propagate the activations through the network have a corresponding set of channels, matching the output activation channels (eg, between B1 and B2 there are 128 weight tensors, one for each output channel of B2).
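In code, the two granularities might look something like this (a sketch using a symmetric absmax scale for brevity; the shapes loosely mirror the B1→B2 example in Figure 1, and the spatial size is made up):

```python
import numpy as np

a = np.random.randn(1, 64, 56, 56).astype(np.float32)   # an activation tensor with 64 channels (like B1)
w = np.random.randn(128, 64, 3, 3).astype(np.float32)   # conv weights with 128 output channels (like B2)

# Activations: one scale for the whole tensor (per-tensor granularity).
s_a = np.abs(a).max() / 127.0
a_q = np.clip(np.rint(a / s_a), -127, 127).astype(np.int8)

# Weights: one scale per output channel, broadcast over the remaining axes (per-channel granularity).
s_w = np.abs(w).max(axis=(1, 2, 3), keepdims=True) / 127.0   # shape (128, 1, 1, 1)
w_q = np.clip(np.rint(w / s_w), -127, 127).astype(np.int8)
```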
Even for a given group (channel/tensor), we don't have to use the min and max values to define our range. There are more sophisticated methods that can clip outlier values, or simpler methods that use only the largest absolute value in the group (a symmetric scheme centered on zero).
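As a toy illustration of why clipping the range can help, here's what a single outlier does to the bin width when we take the raw min/max (the percentile cutoff is just an illustrative heuristic, not any specific method):

```python
import numpy as np

x = np.concatenate([np.random.randn(10_000), [45.0]])   # well-behaved values plus one large outlier

naive_bin = (x.max() - x.min()) / 255                    # min/max range: the outlier stretches every bin

lo, hi = np.percentile(x, [0.05, 99.95])                 # clipped range: sacrifice the outlier,
clipped_bin = (hi - lo) / 255                            # keep fine-grained bins for everything else

print(naive_bin, clipped_bin)   # the clipped bins are several times narrower
```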
Footnotes
The key unit is FLOPs/second/watt. FLOPs/second represents the rate at which the chip can execute floating point operations, and lets us measure how quickly it can perform matrix multiplications, the cornerstone of any deep learning model. Dividing by watts tells us how efficiently the chip can run these operations. ↩︎
The basic processing block for a matrix multiplication is a series of multiplications and sums (ie, for C = AB, multiply every element of row 1 of A by the corresponding element of column 1 of B, and add the products together to get element C₁₁). In hardware, this can be represented by a series of MAC operations: Multiply-Accumulators. One MAC will multiply two numbers and add a third. For floating-point numbers, each MAC takes two FLOPs - one to multiply, one to add. ↩︎
Although there are many metrics for how well models do beyond accuracy, we use the term as a general placeholder for the output quality of the model, since 'model performance' is an overloaded term, often taken to mean how efficiently the model can run. ↩︎
As a brief aside, many incredible scientists have approached these questions from various angles, including (1) traditional ML learning theory and PAC-completeness, (2) Neural Tangent Kernel (NTK) theory, (3) group-theoretic models of semi/un-supervised learning and data augmentation, (4) various fascinating empirical benchmarks that test the core ideas of the task being learned (long-range arena, COGS, arithmetic, needle-in-a-haystack), (5) mechanistic interpretability, ie a functional circuit theory of a model. All of these approaches probe how large models work in the modern era. I hope to elucidate and employ these ideas in upcoming writings, and am keeping this list here for reference. ↩︎
The zero-point is a key term that eliminates the quantization error from quantizing 0. This is critical as some operations perform zero-padding, so any errors here would propagate uncontrollably. ↩︎
We typically use a small calibration dataset to get an idea of what activations look like in test-time. Selection of this dataset can be extremely important as activations often contain outliers that are difficult to quantize. ↩︎