Understanding the TFLite Quantization Approach (Part I)

Running machine learning models can be computationally expensive, and these models don’t typically port well to hardware or embedded devices. The TensorFlow Lite (TFL) library is, according to the documentation, “a mobile library for deploying models on mobile, microcontrollers and other edge devices.”

TFL converts native TensorFlow models into smaller, lighter-weight quantized models that typically operate in reduced precision, yielding both a significant reduction in the cost of running the model and the ability to port quantized models directly to hardware.

Due to these constraints, the quantization scheme used in the TFL library is substantially different from naive affine quantization approaches. This note provides an introduction to the ideas behind the library’s quantization scheme.

Representation of floating point scales

The overall goal of quantization (not just in TFL, but in every library) is to represent floating point weights and activations as reduced-precision data, typically 8-bit integers. While training of these models typically happens in floating point so that gradients can be taken, post-training quantization of a model attempts to store all data as integers, requiring only integer operations on the quantized values. The result? A compressed model with very fast inference.

Throughout this article, we’ll adopt the following general notation: the symbol $q$ will refer to quantized values, $r$ will refer to real mathematical values (infinite precision), and $s$ will refer to scale values stored in floating point. We will use $Z$ to denote an integer-valued “zero-point”.

We’d like to map a set of real values to integers in some range. Informally speaking, we’re looking for a linear mapping of integers $a:\mathbb{N}\to\mathbb{R}$ such that

$$r = a(q) = s(q - Z),$$

where $s$ and $Z$ are parameters to be chosen. That is, for every integer $q$ in some range (for example, $[-128, 127]$ in the case of 8-bit integers), given some floating point scale $s$ and an integer zero-point $Z$, we can reconstruct an approximation to the original continuous value $r$ in floating point using the equation above.
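
For example, with (made-up) parameters $s = 0.05$ and $Z = -3$, the stored 8-bit value $q = 17$ reconstructs

$$r = 0.05 \cdot (17 - (-3)) = 1.0.$$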

More generally, the integer $q$ is quantized as a $B$-bit integer. The number $s\in\mathbb{R}$ is represented using a floating-point number. The integer $Z$ is quantized the same way as $q$ and corresponds to the real value 0, which should always be exactly representable. To see why, take the example of convolutional neural networks. Convolutional deep learning implementations should be able to support padding arrays with zero values. These zero values must be exactly representable as zero; it would be disastrous from an accuracy perspective to introduce a numerical error by using a small nonzero value to approximate zero when performing zero-padding of the arrays during a convolution.

So with the 8-bit quantization approach, instead of storing every array value as a floating point number, we use only one floating point number $s$ to represent the scale, one integer $Z$ to represent the zero point, and 8-bit integers to store the array values. As the bit-width of a double precision floating point number is 64 bits, we’ve reduced our memory requirement by a factor of 8 with this simple linear transformation.
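
As a concrete sketch of this storage scheme, here is one way to derive $s$ and $Z$ from a tensor’s observed range $[r_{\min}, r_{\max}]$ and quantize a buffer of floats to int8. The struct and function names below are illustrative only, not part of the TFL API, and the parameter-selection logic is a minimal version of the idea rather than the library’s implementation.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Illustrative sketch (not the TFL API): derive an affine quantization
// (scale, zero_point) for int8 from an observed real range, so that
// r ≈ scale * (q - zero_point).
struct QuantParams {
  float scale;
  int32_t zero_point;
};

QuantParams ChooseQuantParams(float r_min, float r_max) {
  // Ensure the range contains 0 so that zero is exactly representable.
  r_min = std::min(r_min, 0.0f);
  r_max = std::max(r_max, 0.0f);
  float scale = (r_max - r_min) / 255.0f;  // 255 quantized steps span [-128, 127]
  if (scale == 0.0f) scale = 1.0f;         // degenerate all-zero tensor
  // Solve r_min = scale * (-128 - Z)  =>  Z = -128 - r_min / scale.
  const int32_t zero_point =
      static_cast<int32_t>(std::round(-128.0f - r_min / scale));
  return {scale, std::min(127, std::max(-128, zero_point))};
}

std::vector<int8_t> Quantize(const std::vector<float>& reals, QuantParams p) {
  std::vector<int8_t> q(reals.size());
  for (size_t i = 0; i < reals.size(); ++i) {
    // Invert r = s(q - Z):  q = Z + round(r / s), then saturate to int8.
    const int32_t v =
        p.zero_point + static_cast<int32_t>(std::round(reals[i] / p.scale));
    q[i] = static_cast<int8_t>(std::min(127, std::max(-128, v)));
  }
  return q;
}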

However, we’re still representing our scale parameter $s$ as a floating point number, and many hardware platforms and embedded systems don’t have floating point support. If we’d like to deploy our model on devices such as these, we have to find an alternative representation of the floating point parameter which is still accurate and efficient. The central idea of the TFLite quantization framework is representing the scale parameter value $s$ without using a floating point number by making the decomposition

$$s = \frac{M_0}{2^n} = M_0\, 2^{-n},$$

where $M_0$ is an integer of bit-width $n$. In addition to no longer requiring an explicitly stored floating point number, this representation is advantageous because it reduces multiplication by the value $s$ to an integer multiply and a bitshift operation (i.e., division by a power of two). Using this decomposition, we regain the ability to compute approximate floating point multiplication without needing to represent, store, or multiply an actual floating point number.
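
As a toy example with deliberately small bit-widths: the scale $s = 0.375$ decomposes exactly as $M_0 = 3$ and $n = 3$, so scaling an integer accumulator $x = 40$ becomes

$$x \cdot s = (x \cdot M_0) \gg n = (40 \cdot 3) \gg 3 = 120 \gg 3 = 15 = 40 \cdot 0.375.$$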

Typically $M_0$ is a 32-bit integer ($n=32$), and from here on, we’ll take $M_0$ to always have a 32-bit representation, without loss of generality. Straightforward algebra shows that the quantized multiplier $M_0$ can be computed as $M_0 = \text{int}\left(s\cdot 2^n\right)$, but it may be helpful to solve for the multiplier in the following (equivalent) way:

$$s = 2^{-n}M_0 \\ \log_2(s) = \log_2(M_0) + \log_2(2^{-n}) \\ \Rightarrow M_0 = \text{int}\left(2^{\log_2(s) + n}\right),$$

which admits the interpretation that we’re increasing the binary order-of-magnitude of the scale parameter by $n$. From this operation, we can see that the representation is, in theory, numerically “nicer” when the scale parameter to be quantized is small, i.e., has the property $\log_2(s) < 0$, because $M_0$ can be represented with fewer bits. Similarly, if the scale parameter’s binary order-of-magnitude satisfies $\log_2(s) < -32$, it will be represented as zero numerically.
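
For instance, with $n = 32$, a typical scale such as $s = 0.05$ is stored (up to the choice of rounding) as

$$M_0 = \text{int}\left(0.05 \cdot 2^{32}\right) \approx 214{,}748{,}365,$$

so that $M_0 \cdot 2^{-32} \approx 0.05$, while any scale below $2^{-32} \approx 2.3\times 10^{-10}$ is, as noted above, represented as zero.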

Evaluating $M_0$ using either expression requires a choice of rounding scheme for the final round to an integer; once this is chosen, the decomposition algorithm is completely specified. In the TFL library specifically, $M_0$ is represented as an int32_t and is computed in the method QuantizeMultiplier (defined here), a thin wrapper around std::frexp that quantizes the scalar multiplier while avoiding 32-bit integer overflow (a simplified sketch is given just below). The implementation applies rounding by adding one to the least significant bit of the result if the truncated bits exceed half of the maximum value for that bit-width. For reference, the details of the rounding operation can be seen in the 64-bit integer IntegerFrExp method here, which is selected by a preprocessor directive when floating point instructions are not available. With the integer-bitshift representation of floating point multiplication as well as its quantization machinery in place, we can examine the algorithms used to perform quantization of common operations.
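
Here is that simplified sketch of the multiplier decomposition. The function name and edge-case handling are mine, not the library’s exact QuantizeMultiplier; it ignores negative multipliers and TFL’s precise rounding behavior. std::frexp writes $s = \text{fraction} \cdot 2^{\text{shift}}$ with the fraction in $[0.5, 1)$, so the fraction can be stored as a Q31 fixed-point integer.

#include <cmath>
#include <cstdint>

// Simplified sketch (not TFL's exact implementation): decompose a positive
// multiplier s into a 32-bit integer M0 and an exponent such that
// s ≈ M0 * 2^(shift - 31).
void QuantizeMultiplierSketch(double s, int32_t* quantized_multiplier, int* shift) {
  if (s == 0.0) {
    *quantized_multiplier = 0;
    *shift = 0;
    return;
  }
  // frexp returns the fraction in [0.5, 1) and writes the exponent to *shift.
  const double fraction = std::frexp(s, shift);
  // Round the fraction into Q31 fixed point.
  int64_t m0 = static_cast<int64_t>(std::round(fraction * (1LL << 31)));
  if (m0 == (1LL << 31)) {
    // Rounding pushed the fraction up to exactly 1.0; renormalize.
    m0 /= 2;
    ++(*shift);
  }
  *quantized_multiplier = static_cast<int32_t>(m0);
}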

Quantized multiplication

For the case of element-wise multiplication of two input arrays, consider two numbers, $A_1$ and $A_2$, which are multiplied to form a resulting number $R$, i.e., $A_1 A_2 = R$. The gist of the numerical quantization algorithm follows from the quantization of each input, as well as the output:

$$A_1 A_2 = s_1(q_1 - Z_1)\, s_2(q_2 - Z_2), \\ R = s_R(q_R - Z_R).$$

For each of the numbers $A_1$, $A_2$, and $R$, $q$ represents the 8-bit quantization of the floating point number, $s$ is the floating point representation of the scale parameter, and $Z$ represents the 8-bit integer zero-point. Representing each floating point scale $s$ using the quantization scheme above, we have (since $s \approx M_0 2^{-n}$):

$$A_1 A_2 \approx M_{0,1}2^{-n_1}(q_1 - Z_1) \cdot M_{0,2}2^{-n_2}(q_2 - Z_2) = M_{0,R}2^{-n_R}(q_R - Z_R) \approx R.$$

We can rearrange this to form the quantized outputs in terms of the inputs:

$$q_R = Z_R + \frac{M_{0,1} M_{0,2}}{M_{0,R}}\cdot 2^{-(n_1 + n_2 - n_R)}\cdot (q_1 - Z_1)(q_2 - Z_2).$$

However, it would be inefficient to compute the three quantized multipliers $M_{0,1}$, $M_{0,2}$, and $M_{0,R}$ and to apply the three bitshifts separately. On top of introducing overflow complexities, doing so would require three multiply operations and three bitshift operations (by $n_1$, $n_2$, and $n_R$) instead of one of each. Instead, we can use the equivalent original arithmetic relation (before the power-of-two quantization) to write

$$q_R = Z_R + \frac{s_1 s_2}{s_R}(q_1 - Z_1)(q_2 - Z_2) = Z_R + M^{\prime}(q_1 - Z_1)(q_2 - Z_2),$$

where we’ve made the definition $M^\prime = s_1 s_2 / s_R$, and quantize the single “real multiplier” $M^\prime$ in one shot. That is, we compute $M^\prime$ in floating point and then quantize it. If we look at the source code of the Prepare method in the quantized multiplication source code, we see the following lines (comments are mine), which mirror what we’ve derived above:

// Initialization, error checking...

// Quantization 
double real_multiplier = input1->params.scale * input2->params.scale / output->params.scale;
QuantizeMultiplier(real_multiplier, &data->output_multiplier, &data->output_shift);
// ...
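
Putting the pieces together, the per-element arithmetic at inference time then looks roughly like the sketch below. The helper and function names are mine (not TFL’s), the convention follows the frexp sketch above ($s \approx M_0 \cdot 2^{\text{shift}-31}$), and real kernels additionally handle broadcasting, activation clamping, and exact rounding of negative values.

#include <algorithm>
#include <cstdint>

// Illustrative helper (names are mine): computes approximately
// round(x * M0 * 2^(shift - 31)) using only integer arithmetic,
// assuming 1 <= 31 - shift <= 62 (the common case of a small scale).
int32_t MultiplyByQuantizedMultiplierSketch(int32_t x, int32_t M0, int shift) {
  const int64_t product = static_cast<int64_t>(x) * static_cast<int64_t>(M0);
  const int total_shift = 31 - shift;
  const int64_t rounding = int64_t{1} << (total_shift - 1);
  return static_cast<int32_t>((product + rounding) >> total_shift);
}

// One element of the quantized multiply: q_R = Z_R + M'(q1 - Z1)(q2 - Z2),
// with M' stored as (output_multiplier, output_shift).
int8_t QuantizedMulElementSketch(int8_t q1, int32_t Z1, int8_t q2, int32_t Z2,
                                 int32_t ZR, int32_t output_multiplier,
                                 int output_shift) {
  const int32_t raw =
      (static_cast<int32_t>(q1) - Z1) * (static_cast<int32_t>(q2) - Z2);
  const int32_t scaled =
      MultiplyByQuantizedMultiplierSketch(raw, output_multiplier, output_shift);
  // Saturate the result back into the 8-bit output range.
  return static_cast<int8_t>(std::min(127, std::max(-128, ZR + scaled)));
}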

Quantized subtraction

When considering binary operations such as addition or subtraction, we can’t group factors as conveniently as with multiplication. For example, in the case of a quantized subtraction operation $A_1 - A_2 = R$ (using the same notation as above):

$$A_1 - A_2 = s_1(q_1 - Z_1) - s_2(q_2 - Z_2) = s_R(q_R - Z_R) = R.$$

This time, we can’t group all the scales into a single parameter due to the subtraction operation. Furthermore, as we’ve made no prior assumptions on the numbers $A_1$, $A_2$, or $R$, it’s straightforward to see how we could run into integer overflow if all the scales are naively represented using the power-of-two quantization approach above. Therefore, we’d like to scale the equation so that all the scale parameters are roughly of the same order of magnitude. We have that

$$s_1 + s_2 \leq 2 \cdot \max(s_1, s_2),$$

and therefore, since $s_1 \leq \max(s_1, s_2)$, it immediately follows that

$$0 < \frac{s_1}{2\cdot \max(s_1, s_2)} \leq \frac{1}{2}.$$

Therefore, we can define the scalar $\alpha = 2\cdot \max(s_1, s_2)$ and rescale the original quantized subtraction equation as

$$\frac{s_1}{\alpha}(q_1 - Z_1) - \frac{s_2}{\alpha}(q_2 - Z_2) = \frac{s_R}{\alpha}(q_R - Z_R),$$

which now has the property that the two scale parameters on the left-hand side lie in the interval $(0, 0.5]$. It also stands to reason that if the rescaled left-hand side scales are both on the order of $1$, then their subtraction’s scale will also be of similar order (do you see why?)*[footnote: The $q$ and $Z$ parameters are bounded by the 8-bit integer range $[-128, 127]$; the scaled multipliers are bounded between $0$ and $1/2$. Can you write down a bound on the output scale?]. Therefore, straightforward rearrangement gives

$$q_R = Z_R + \frac{\alpha}{s_R}\left[\frac{s_1}{\alpha}(q_1 - Z_1) - \frac{s_2}{\alpha}(q_2 - Z_2)\right] \\ = Z_R + M_R\left[M_1(q_1 - Z_1) - M_2(q_2 - Z_2)\right],$$

and now we must use the three quantized multipliers $M_1$, $M_2$, and $M_R$ to compute the quantized output $q_R$. In the TFL source code, this pre-scaling is done in the method PrepareGeneralSubOp (definition here). You’ll find the analogous manipulation in the following (comments are mine):

// tensorflow/tensorflow/lite/kernels/sub.cc
// ...

// compute alpha
const double twice_max_input_scale =
    2 * std::max(input1_quantization_params.scale,
                 input2_quantization_params.scale);

// compute the scaled multipliers
const double real_input1_multiplier = input1_quantization_params.scale / twice_max_input_scale;
const double real_input2_multiplier = input2_quantization_params.scale / twice_max_input_scale;

// compute the output multiplier; the (1 << left_shift) factor accounts for a
// fixed pre-shift the kernel applies to the inputs for extra precision
const double real_output_multiplier =
  twice_max_input_scale / ((1 << op_params->left_shift) * output_quantization_params.scale);

With the reasoning above, the source code becomes simple to follow.
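
As a worked example with made-up scales: if $s_1 = 0.02$, $s_2 = 0.03$, and $s_R = 0.05$, then $\alpha = 2\max(s_1, s_2) = 0.06$, and the three real multipliers to be quantized are

$$M_1 = \frac{s_1}{\alpha} = \frac{1}{3}, \qquad M_2 = \frac{s_2}{\alpha} = \frac{1}{2}, \qquad M_R = \frac{\alpha}{s_R} = 1.2,$$

each of which is then decomposed into an integer multiplier and a shift exactly as in the multiplication case (ignoring, for simplicity, the extra left_shift factor that appears in the code above).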

Extensions

The discussion and derivations above aim to provide a working understanding of how quantized operations are implemented in the TFL library. Other operations follow from similar logic. For example, the quantization of matrix multiplication involves both sums and multiplications, but is derivable in a straightforward way using reasoning similar to the above. If you’re interested in testing your understanding of the TFL quantization scheme, see if you can derive it; the answer, as well as further discussion, is given in the TFL quantization paper.

There are many additional subtleties behind the implementations in the TFL library. We’ll discuss them, as well as the original paper, in future posts.