The Math of AWQ: Protecting Salient Channels from the Inside Out
When we compress a Large Language Model (LLM) from 16-bit floating-point numbers to 4-bit integers, we save a massive amount of memory. But compression creates mathematical damage. We are taking a finely spaced range of floating-point values and forcing them onto a rigid grid with only a handful of integer levels.
Activation-aware Weight Quantization (AWQ) is a brilliantly simple method to mitigate this damage. To understand why it works, we must build the math from the ground up.
Step 1: The Anatomy of a Quantized Weight
The core of quantization is mapping a floating-point weight (\(W\)) to an integer code, and then mapping it back. We define this entire round-trip process using a step size (\(\Delta\)):
\[Q(W) = \Delta \cdot \text{Round}\left(\frac{W}{\Delta}\right)\]
This elegant equation actually represents both halves of the process:
- Quantization: We divide the weight by \(\Delta\) to find out how many “steps” it represents, and Round() it to snap it to the nearest whole integer.
- Dequantization: We multiply that integer code back by \(\Delta\) to stretch it back into a floating-point approximation that the GPU can use for math.
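To make the round trip concrete, here is a minimal sketch in plain Python. The function names and the step size of 0.5 are illustrative assumptions, not anything taken from the paper, and the sketch skips clipping to the integer range.

```python
def quantize(w: float, delta: float) -> int:
    """Map a float weight to its integer code: Round(w / delta)."""
    # Python's round() breaks exact .5 ties toward the even integer.
    return round(w / delta)


def dequantize(code: int, delta: float) -> float:
    """Map an integer code back to a float approximation: delta * code."""
    return delta * code


delta = 0.5                       # illustrative step size, chosen arbitrarily
w = 1.3
code = quantize(w, delta)         # 1.3 / 0.5 = 2.6, which rounds to 3
w_hat = dequantize(code, delta)   # 3 * 0.5 = 1.5
print(code, w_hat)
```

The reconstructed value lands within half a step of the original, which is the most a round-to-nearest grid can promise.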
But how do we calculate \(\Delta\) in the first place?
Step 2: Defining the Step Size (\(\Delta\))
In symmetric quantization, we want to center our integer grid perfectly around zero.
First, we look at our tensor of weights and find the largest absolute value. We call this maximum parameter \(a\):
\[a = \max(|W|)\]
This defines the floating-point range we need to capture: \([-a, +a]\). We must map this continuous interval onto our available discrete integer bins. For a bit-width of \(b\) (like INT4 or INT8), the signed integer range is:
\[[-(2^{b-1}-1), +(2^{b-1}-1)]\]
For INT4 (\(b=4\)), this gives us an integer range of \([-7, +7]\). (Note: While standard INT4 can hold -8, we use a restricted range of -7 to +7 to keep the symmetry perfectly balanced around zero.)
Because we are mapping the total float distance (\(2a\)) to the total integer distance (\(14\)), the formula for our step size is:
\[\Delta = \frac{a - (-a)}{(2^{b-1}-1) - (-(2^{b-1}-1))} = \frac{2a}{2(2^{b-1}-1)} = \frac{a}{2^{b-1}-1}\]
The Intuition Example: Imagine a group of weights where the maximum absolute value is 10 (\(a = 10\)). For our INT4 grid, \(\Delta = 10 / 7 \approx 1.43\). \(\Delta\) is the slope of the line connecting our codes to reality. It means that every time our integer code moves by 1, the real floating-point value jumps by 1.43.
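The step-size formula can be sketched in a few lines of Python. The helper names and the four example weights are made up for illustration, and the code assumes the group contains at least one nonzero weight.

```python
def step_size(weights, bits=4):
    """Symmetric step size: delta = max|w| / (2^(bits-1) - 1)."""
    a = max(abs(w) for w in weights)
    q_max = 2 ** (bits - 1) - 1
    return a / q_max


def quantize_group(weights, bits=4):
    """Return (integer codes, delta) for weights that share one step size."""
    delta = step_size(weights, bits)
    q_max = 2 ** (bits - 1) - 1
    # Clamp to the restricted symmetric range [-q_max, +q_max].
    codes = [max(-q_max, min(q_max, round(w / delta))) for w in weights]
    return codes, delta


weights = [10.0, -3.2, 0.7, 4.9]   # made-up group with a = 10
codes, delta = quantize_group(weights)
restored = [c * delta for c in codes]
print(delta, codes, restored)
```

With \(a = 10\) the step size comes out at \(10/7\), and the small weight 0.7 collapses to code 0 because it sits less than half a step from zero.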
Step 3: Deriving the Output Error
A neural network layer works by multiplying an input activation (\(x\)) by a weight (\(W\)). The perfect output is \(Wx\). The compressed output is \(Q(W)x\).
To see the damage caused by compression, we calculate the Output Error (\(\text{Err}\)):
\[\text{Err}(Q(W)x) = Q(W)x - Wx\]
\[\text{Err}(Q(W)x) = (Q(W) - W)x\]
Let’s look at \((Q(W) - W)\). What is the difference between a quantized weight and the true weight? It is simply the rounding error, measured in steps, multiplied by our step size. Writing \(\text{RoundErr}(z) = \text{Round}(z) - z\), we can formalize this as:
\[(Q(W) - W) = \Delta \cdot \text{RoundErr}\left(\frac{W}{\Delta}\right)\]
Substitute this back into our equation, and we arrive at the foundational formula of the AWQ paper:
\[\text{Err}(Q(W)x) = \Delta \cdot \text{RoundErr}\left(\frac{W}{\Delta}\right) \cdot x\]
Step 4: The Statistical Reality of Rounding
Look closely at the error equation. It has three parts: \(\Delta\), the Rounding Error, and the activation \(x\).
The paper makes a crucial statistical observation about the RoundErr term. When you divide thousands of random weights by \(\Delta\), their fractional parts are spread evenly between 0 and 1. Therefore, when you round them to the nearest integer, the errors are uniformly distributed between -0.5 and +0.5.
If we only care about the magnitude of the error, the average absolute rounding error is roughly 0.25 steps (the mean of \(|e|\) when \(e\) is uniform on \([-0.5, 0.5]\)). In practice, we use group-wise quantization, meaning a group of weights (commonly 128) is forced to share the exact same \(\Delta\).
- If \(\Delta\) is fixed for the group…
- And RoundErr averages out to roughly 0.25 in magnitude…
Then the expected magnitude of the quantization error in the output is proportional to the magnitude of the input activation (\(x\)).
If the activation flowing into a weight is small, the error is tiny. But if the network relies on a massive, “salient” activation channel, that large \(x\) multiplies with the grid misalignment and produces a catastrophic discrepancy in the final output.
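Both claims are easy to check numerically. The sketch below is a toy simulation, not an experiment from the paper: the Gaussian weights, the step size of 0.1, and the two activation values are all assumptions chosen for illustration.

```python
import random


def round_err(z: float) -> float:
    """Signed rounding error in units of steps: Round(z) - z."""
    return round(z) - z


rng = random.Random(0)
delta = 0.1                                              # assumed shared step size
weights = [rng.gauss(0.0, 1.0) for _ in range(20_000)]   # synthetic weights

# Average |RoundErr| across all weights; should sit near 0.25.
mean_abs_err = sum(abs(round_err(w / delta)) for w in weights) / len(weights)


def mean_output_error(x: float) -> float:
    """Average |delta * RoundErr(w / delta) * x| over the sampled weights."""
    return sum(abs(delta * round_err(w / delta) * x) for w in weights) / len(weights)


small = mean_output_error(0.5)    # a quiet activation
large = mean_output_error(50.0)   # a salient activation, 100x larger
print(mean_abs_err, small, large, large / small)
```

The rounding term stays put near 0.25 no matter what, so an activation that is 100 times larger produces an output error that is 100 times larger.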
Step 5: The AWQ Solution: Scaling Down Activations
We are trapped by algebra. We cannot shrink \(\Delta\) or stop the rounding error. To save the network, we must mathematically reduce the input activation magnitude (\(x\)).
We do this by introducing a scaling factor, \(s\). We scale down the dangerous activation by dividing it:
\[\text{New Activation} = \frac{x}{s}\]
But to keep the layer’s overall computation mathematically identical (re-parameterization), we must multiply the weight by that exact same \(s\):
\[\text{New Weight} = W \cdot s\]
In exact arithmetic, \((W \cdot s)(x / s) = Wx\). The network’s intelligence remains untouched. But the magic happens during quantized inference, because the rounding now acts on the scaled weights.
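Here is a small sketch of that effect. It deliberately holds \(\Delta\) fixed at 0.25 and ignores clipping, which isolates the benefit of the smaller activation; the weights, the activation of 12, and the scale of 2 are made-up values.

```python
import math
import random

rng = random.Random(0)
delta = 0.25        # step size held fixed on purpose (simplifying assumption)
x, s = 12.0, 2.0    # made-up activation and scale
weights = [rng.uniform(-1.0, 1.0) for _ in range(10_000)]

# Without quantization, the rescaled layer computes the same output.
assert all(math.isclose((w * s) * (x / s), w * x) for w in weights)


def mean_abs_error(scale: float) -> float:
    """Average |Q(w * scale) * (x / scale) - w * x| with delta unchanged."""
    total = 0.0
    for w in weights:
        q = delta * round(w * scale / delta)   # quantize the scaled weight
        total += abs(q * (x / scale) - w * x)
    return total / len(weights)


plain = mean_abs_error(1.0)    # naive quantization
scaled = mean_abs_error(s)     # weight scaled up, activation scaled down
print(plain, scaled, plain / scaled)
```

Under these assumptions the average error drops by roughly a factor of \(s\), because the rounding error is the same size but is multiplied by a smaller activation.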
Step 6: The “Delta Dash” Trap and the Tug of War
If dividing the activation by \(s\) is good, why not use a massive scale like \(s=100\) and shrink the error to zero?
Because of the group \(\Delta\). Remember, \(\Delta\) is defined by the maximum weight in the group. If we take a salient weight (\(W\)) and multiply it by a scale \(s > 1\), that weight grows in magnitude. If we scale it so much that it becomes the new maximum for the group, it inflates the step size.
The paper defines this new inflated step size as \(\Delta'\) (Delta Dash). Our output error equation transforms into a brutal tug of war:
\[\text{Err}\left(Q(Ws)\left(\frac{x}{s}\right)\right) = \Delta' \cdot \text{RoundErr}\left(\frac{Ws}{\Delta'}\right) \cdot \left(\frac{x}{s}\right)\]
The Trade-off in Action:
- The Good: By shrinking the activation multiplier to \((x/s)\), we drastically reduce the effective contribution of quantization error from the channels where it matters most.
- The Bad: If \(s\) is too large, the scaled weight \(Ws\) pushes \(\Delta'\) higher (\(\Delta' > \Delta\)). A wider step size ruins the grid precision for every other normal weight in that group.
We must find the perfect scale \(s\) that shrinks the activation, but stops just before it inflates \(\Delta'\).
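The tug of war shows up in a toy simulation. Everything here is an assumption made for illustration: groups of 128 Gaussian weights, Gaussian activations with channel 0 forced to 50, and a scale applied to that one channel only.

```python
import random

Q_MAX = 7  # symmetric INT4 range [-7, +7]


def quantize_group(weights):
    """Return (dequantized weights, step size) for one group sharing a delta."""
    delta = max(abs(w) for w in weights) / Q_MAX
    return [delta * round(w / delta) for w in weights], delta


def output_error(weights, acts, s, salient=0):
    """Error of the quantized dot product when only channel `salient` is scaled."""
    scaled_w, scaled_x = list(weights), list(acts)
    scaled_w[salient] *= s
    scaled_x[salient] /= s
    w_hat, delta_prime = quantize_group(scaled_w)
    err = sum((q - w) * x for q, w, x in zip(w_hat, scaled_w, scaled_x))
    return err, delta_prime


rng = random.Random(0)
scales = [1.0, 2.0, 4.0, 100.0]
trials = 200
mean_abs_err = {s: 0.0 for s in scales}
mean_delta = {s: 0.0 for s in scales}
for _ in range(trials):
    weights = [rng.gauss(0.0, 1.0) for _ in range(128)]
    acts = [rng.gauss(0.0, 1.0) for _ in range(128)]
    acts[0] = 50.0   # one salient activation channel
    for s in scales:
        err, delta_prime = output_error(weights, acts, s)
        mean_abs_err[s] += abs(err) / trials
        mean_delta[s] += delta_prime / trials

for s in scales:
    print(f"s={s:>5}: mean |error| = {mean_abs_err[s]:.3f}, mean delta' = {mean_delta[s]:.3f}")
```

In this setup a modest scale lowers the average error, while a very large one inflates \(\Delta'\) so much that the ordinary weights in the group lose their precision and the error climbs again.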
Step 7: Optimizing the Scaling Factors (\(\alpha\))
How do we find this perfect \(s\)? In traditional deep learning, we would train the neural network using gradients and backpropagation.
But AWQ avoids this. Because the error relies on a Round() function, which looks like a flat staircase on a graph, its derivative is zero almost everywhere and undefined at the jumps. Gradients cannot flow through step functions reliably, making gradient-based optimization unstable and memory-hungry.
Instead, the authors use pure logic. We know a weight channel only needs to be scaled up if the activations flowing through it are large. Therefore, the scale \(s\) should just be a mathematical reflection of the average activation magnitude (\(s_X\)).
They define the scale using a single dial, \(\alpha\):
\[s = s_X^\alpha\]
Rather than training, they perform a lightning-fast Grid Search over \(\alpha \in [0, 1]\):
- When \(\alpha = 0\): There is effectively no scaling (\(s=1\)). We get naive quantization, and loud activations destroy the output.
- When \(\alpha = 1\): We get the most aggressive scaling possible (\(s=s_X\)). We fully protect the salient channel, but we risk inflating \(\Delta'\) and ruining the rest of the group.
By testing a few points between 0 and 1, AWQ elegantly finds the \(\alpha\) that optimally balances the tug of war: scaling salient channels up just enough to land on cleaner grid points, shrinking the dangerous activations, and keeping the group’s \(\Delta\) safe.
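A simplified sketch of that search for a single group of weights might look like the following. The function names, the 11-point grid, the mean-squared-error objective, and the synthetic calibration data are all assumptions for illustration; the released AWQ code may differ in its details.

```python
import random

Q_MAX = 7  # symmetric INT4 range [-7, +7]


def quantize_group(weights):
    """Dequantized weights for one group sharing a single step size."""
    delta = max(abs(w) for w in weights) / Q_MAX
    return [delta * round(w / delta) for w in weights]


def mse_for_alpha(weights, samples, s_x, alpha):
    """Mean squared output error over the samples for s = s_x ** alpha."""
    scales = [m ** alpha for m in s_x]
    w_hat = quantize_group([w * s for w, s in zip(weights, scales)])
    # Folding 1/s into the dequantized weight is equivalent to feeding x / s.
    w_eff = [q / s for q, s in zip(w_hat, scales)]
    total = 0.0
    for x in samples:
        err = sum((we - w) * xi for we, w, xi in zip(w_eff, weights, x))
        total += err * err
    return total / len(samples)


def search_alpha(weights, samples, n_grid=10):
    """Try alpha = 0, 1/n_grid, ..., 1 and keep the one with the lowest error."""
    n = len(weights)
    # s_X: average activation magnitude per input channel.
    s_x = [sum(abs(x[i]) for x in samples) / len(samples) for i in range(n)]
    grid = [k / n_grid for k in range(n_grid + 1)]
    losses = {alpha: mse_for_alpha(weights, samples, s_x, alpha) for alpha in grid}
    return min(losses, key=losses.get), losses


rng = random.Random(0)
weights = [rng.gauss(0.0, 1.0) for _ in range(64)]
samples = [[rng.gauss(0.0, 1.0) for _ in range(64)] for _ in range(32)]
for x in samples:
    x[0] *= 30.0   # make channel 0 salient in the synthetic calibration data

best_alpha, losses = search_alpha(weights, samples)
print(best_alpha, losses[0.0], losses[best_alpha])
```

Because \(\alpha = 0\) is always on the grid, the search can never do worse than naive quantization on the calibration samples; how much better it does depends on the data.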
References
- Hamza Elshafie, AWQ: Activation-aware Weight Quantisation
- Lin et al., AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration (original paper)