How to quantize a neural network model in PyTorch: an in-depth explanation

Oscar Wiljam Savolainen, PhD
Jan 29, 2024 · 25 min read


For a background on why neural network quantization is attractive, I would recommend this Medium post: https://medium.com/sharechat-techbyte/neural-network-compression-using-quantization-328d22e8855d. For a comprehensive background on what it means to quantize a floating-point tensor, I would recommend a video I made: https://www.youtube.com/watch?v=rzMs-wKQU_U . I would also generally recommend my channel on AI, where I have a playlist where I go over all things quantization: https://www.youtube.com/@OscarSavolainen. E.g. I have a video where I quantize a ResNet entirely from scratch in Eager Mode: https://www.youtube.com/watch?v=jNZ1rkIfwsM&feature=youtu.be. I would also very strongly recommend the highly respected white paper by Qualcomm, which is a fantastic resource on int8: https://arxiv.org/abs/2106.08295.

For the rest of this blog, I will assume you already know that you want to quantize a neural network, are basically familiar with the background, and are looking for a tutorial on how to do it in code, specifically in PyTorch. If so, welcome: you're in the right place, and in this post we will go over exactly that! :)

Specifically, we will go over Eager Mode quantization. PyTorch also has newer quantization modes called FX Graph mode and Export mode, which are very similar to each other. They involve working on the graph (i.e. on a GraphModule object), and so I will have another post about them as they work a bit differently and require some explanation of graph navigation and manipulation. For now, eager mode!


Where is the quantization code in PyTorch?

The first thing we need to know is where the quantization code in PyTorch is located, and how it is structured. For PyTorch, almost everything you need for quantization, at least on the Python-side, will be located in https://github.com/pytorch/pytorch/tree/main/torch/ao.

Like most of PyTorch, the quantization code is accelerated with C++ ATen kernels. Really understanding what PyTorch is doing during quantization at the code level can therefore involve digging through the ATen side of PyTorch, e.g. https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/quantized/FakeQuantPerTensorAffine.cpp.

Throughout this post, I will assume we are doing int8 quantization (and not int4, int2, etc.), and so you will see reference to quantization grids with 2**8 = 256 values. However, the discussion generalizes to other quantization resolutions. I will also assume uniform quantization, since that is what is typically supported by hardware.
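To ground that, here is a purely illustrative sketch of uniform affine quantization of a floating-point tensor onto a 256-value grid, i.e. the scale / zero-point arithmetic that everything below builds on. This is just the arithmetic, not PyTorch's internal implementation:

import torch

# Purely illustrative sketch of uniform affine quantization to an unsigned
# 8-bit grid (256 values).
x = torch.randn(4, 4)

qmin, qmax = 0, 255
scale = (x.max() - x.min()) / (qmax - qmin)        # width of one quantization bin
zero_point = qmin - torch.round(x.min() / scale)   # the integer that represents 0.0

x_int = torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax)  # integer representation
x_fq = (x_int - zero_point) * scale                                   # "fake-quantized" float values

print((x - x_fq).abs().max())  # quantization error, roughly scale/2 at most for in-range values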

Once we know where the quantization code is, we can get started!

There are six steps to quantizing a model in PyTorch. These are:

  1. Required Architecture Changes
  2. Fusing modules
  3. Attaching QConfigs
  4. Preparing the model for quantization, i.e. fake-quant
  5. Converting the model to a “true” int8 model
  6. Unit Testing

Step 1: Required Architecture Changes

Unfortunately, it’s often not possible to just take an arbitrary floating-point model and quantize it without making some architecture changes. There are 2 classes of changes we need to make just to get the model to quantize: adding Quant/DeQuant Stubs, and replacing some operations with FloatFunctionals. This is separate from the architecture changes that can give a model better quantization performance, which are an area of immense public and private research that I will not go into here.

Quant/DeQuant Stubs:

These are quite intuitive. Quant and DeQuant stubs act as the input/output gates to the quantized parts of your model. A QuantStub takes a floating-point tensor, and converts it to a quantized tensor. A DeQuant Stub converts a quantized tensor to a floating-point tensor. I think the QuantWrapper class gives a very good visual demonstration as to what these stubs do and how they are used. If we simplify it a bit, it looks like this:

import torch
import torch.nn as nn

class QuantWrapper(nn.Module):
    def __init__(self, module):
        super().__init__()
        self.module = module
        self.quant = torch.ao.quantization.QuantStub()
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, x):
        # In this case, we use a `_q` suffix to denote a quantized tensor
        x_q = self.quant(x)     # quantizes the incoming floating-point tensor
        y_q = self.module(x_q)  # runs the model with a quantized input
        y = self.dequant(y_q)   # dequantizes the output back to floating point
        return y

You might be wondering, from the example above, is that enough to quantize the model? Can we just feed it a quantized tensor, and we’re done? Unfortunately, no. We will have to do a bunch of stuff to the model to make it quantize, and we’ll cover those later on in this blog post!

Something to be aware of is that QuantStubs are stateful: they come with quantization parameters (i.e. qparams) that will define how the incoming floating-point tensor is quantized. However, DeQuantStubs are stateless: they merely take a quantized tensor, and cast it to a floating point tensor. As such, if one is dequantizing multiple tensors one could use the same self.dequant instance without any problems. Similarly, there is nothing inherently wrong with using the same QuantStub for 2 tensors that you want to quantize in exactly the same way, e.g. if you are quantizing a pair of image tensors where you know exactly what the quantization grid should be, and are happy to use the same QuantStub. However, generally speaking, different floating-point tensors that are to be quantized should use different QuantStubs.

I would generally recommend against using the QuantWrapper class. Especially when you’re starting out with quantization, I would recommend placing the Quant/DeQuant stubs manually. This will give you experience with placing them. Furthermore, you’ll know exactly where they are in the model state-dict and forward calls, and as a result things will be less likely to go wrong. However, if you merely have to wrap your entire model in Quant/DeQuant stubs, then there would be nothing wrong with using the QuantWrapper class, even for a beginner.
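To make the manual placement concrete, here is a hedged sketch of a hypothetical model with two floating-point inputs: each input gets its own QuantStub (with its own qparams), while a single DeQuantStub is re-used for the outputs. The layer names and shapes are made up purely for illustration:

import torch
import torch.nn as nn
import torch.ao.quantization as tq

# Hedged sketch of manual stub placement for a hypothetical two-input model.
class TwoInputModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant_x = tq.QuantStub()    # each float input gets its own QuantStub (own qparams)
        self.quant_y = tq.QuantStub()
        self.dequant = tq.DeQuantStub()  # stateless, so a single instance can be re-used
        self.conv_x = nn.Conv2d(3, 8, 3)
        self.conv_y = nn.Conv2d(3, 8, 3)

    def forward(self, x, y):
        x_q = self.quant_x(x)                   # float -> quantized
        y_q = self.quant_y(y)
        out_x = self.dequant(self.conv_x(x_q))  # quantized -> float
        out_y = self.dequant(self.conv_y(y_q))
        return out_x, out_y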

FloatFunctionals:

FloatFunctionals are quite interesting. They exist because, unfortunately, some standard operations (*, +, -, / and concat) do not work for quantized tensors. But why don’t they work?

Let’s say we have two quantized tensors, A and B, where A ∈ [0, 1] and B ∈ [0, 1]. We would like to add them together to get C. The question is: what should the qparams of C be?

One might say that C ∈ [0, 2], and so the quantization grid should be {0, 2/255, 4/255, … , 2*255/255}. However, it may be that, because of the relative distributions of A and B in our application, in practice C ∈ [0, 1.1]. If that’s the case, then there is no point wasting quantization range on [1.1, 2]. As such, for int8 quantization with no clamping error, our scale qparam would be 1.1/255 instead of 2/255, giving much better quantization resolution.

I’ll talk about how to obtain good quantization parameters later, but for now we just have to understand that certain very basic operations (e.g. adding, multiplying or concatenating two tensors) require a requantization step with its own qparams. That is what FloatFunctionals (FF) do. They’re basically just a special class that provides methods for these simple operations, with a requantization step built in.

To make your model quantizable, you will have to replace your *, +, -, / and concat operations with these FFs. FFs are a little bit trickier to get used to than QuantStubs, but once you get the basic idea they will also be easy! The way FFs are implemented in Torch is as a FF class, with different methods for all of the different operators. As of January 2024, there are 6 natively supported FF methods:

  • add
  • cat
  • mul
  • add_relu
  • add_scalar
  • mul_scalar

A very simple example of doing a subtraction with FFs looks like this:

import torch
import torch.nn as nn

class SubWrapper(nn.Module):
    def __init__(self):
        super().__init__()
        self.neg = torch.nn.quantized.FloatFunctional()
        self.addition = torch.nn.quantized.FloatFunctional()

    def forward(self, x, y):
        neg_y = self.neg.mul_scalar(y, -1)  # we use the `mul_scalar` method to multiply `y` by `-1`
        out = self.addition.add(x, neg_y)   # we use the `add` method to add `x` and `-y`
        return out

We can see that we initialize two FFs as part of the model. We give them the names neg and addition, and then call the appropriate methods to subtract y from x. If we were doing multiplication of x * (-y), we could call the method mul instead of add.

There is some subtlety when it comes to the add_scalar and mul_scalar methods: in PyTorch, these operations don’t have their “own” qparams. This is because the qparams for the outputs of add_scalar and mul_scalar operators can be perfectly derived from those of the scalar in question and the incoming tensor. From the above example, the qparams for neg_y can be perfectly derived from those of y and the -1 scalar. As such, under the hood, PyTorch ignores the qparams of the FF when doing add_scalar and mul_scalar, and uses the incoming tensor and scalar to derive the qparams in the ATen kernel. This may not be relevant to your application, but if the “correctness” of the state-dict is key to your application, then be aware that the qparams in the state-dict will not actually be “correct”, i.e. they will be ignored by the PyTorch code, and won’t be updated via gradient descent (in QAT, discussed later) or observation (in PTQ, also discussed later) in the case of add_scalar and mul_scalar FFs.

There currently (Jan 2024) isn’t an FF for quantized division in PyTorch, as the kernel hasn’t been implemented. A simple workaround is to dequantize prior to doing the division, and then requantize the output of the division, as sketched below. A lot of hardware, e.g. Intel and Qualcomm chips, does support quantized division, but PyTorch has not implemented it yet.
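Here is a hedged sketch of that workaround (the module and attribute names are mine, not a PyTorch API): leave quantized space, divide in floating point, then pass the result through a QuantStub so the division’s output gets its own qparams.

import torch
import torch.nn as nn
import torch.ao.quantization as tq

# Hedged sketch of the dequantize -> divide -> requantize workaround.
class DequantDivideRequant(nn.Module):
    def __init__(self):
        super().__init__()
        self.dequant_num = tq.DeQuantStub()
        self.dequant_den = tq.DeQuantStub()
        self.requant = tq.QuantStub()  # owns the qparams of the division's output

    def forward(self, numerator_q, denominator_q):
        numerator = self.dequant_num(numerator_q)     # back to floating point
        denominator = self.dequant_den(denominator_q)
        return self.requant(numerator / denominator)  # plain float division, then re-quantize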

In summary for FFs, you will need to go through your model definition and replace all of the relevant operators with FFs. These are stateful and will have qparams that should be customized to the operation in question, so one should generally initialize a unique FF for each operation that one does. Those are all of the architecture changes we have to make: Quant/DeQuant stubs, and FloatFunctionals!

Step 2: Fusing Modules

“When properly fused, the single being created has an astounding level of power, far beyond what either fusees would have had individually.” — Dragon Ball Wiki

Fusing modules is the next step. It is optional, but highly recommended. Fusing merely consists of replacing certain adjacent layers/modules in your model with fused equivalents. For example, if one has a Conv layer followed by a ReLU non-linearity, it would be a shame in quantization to be forced to do the Conv operation, requantize the output, then perform a ReLU, and then requantize the output again. The qparams of the Conv layer should perfectly match those of the ReLU. Therefore, rather than waste computation on 2 requantization steps, add unnecessary data movement that slows down your model, and introduce the risk of unnecessary quantization error, one can fuse the ReLU into the Conv operation. In a fused ConvReLU, as one writes the values outputted by the Conv to a buffer, one clamps the negative values to zero. It is much faster than doing the operations separately, and eliminates any risk of extra quantization error!

There is also another reason to fuse modules: Post-Training Quantization (PTQ). I will discuss PTQ more later, but for now let’s make a note that we will circle back to another reason why fusing is very important later.

There are 2 methods to fuse modules: a name-based, Eager Mode method, and a graph-based method. In the name-based method, one manually feeds in a list of lists, with each sublist containing the names of the modules that should be fused together. Eager Mode quantization currently natively only supports the fusing of a small handful of module combinations:

  • conv, bn
  • conv, bn, relu
  • conv, relu
  • linear, bn
  • linear, relu

An example use of the Eager Mode, name-based function is given below:

torch.quantization.fuse_modules(
    model,
    [
        ['conv1', 'relu1'],
        ['conv2', 'relu2'],
        ['submodule_A.conv1', 'submodule_A.relu1'],
    ],
    inplace=True,
)

I have a video coding tutorial available on how to quantize a ResNet model entirely from scratch. However, as I was preparing that tutorial I ran into a problem when fusing the modules. The problem was that because ReLUs are stateless, many coders will re-use the same ReLU instance, or use a functional API, for their ReLUs within a module. Combined with fusing, this can cause hard-to-debug errors. For example, let’s say one fuses the following module:

import torch

class Block(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(3, 5, 3)
        self.conv2 = torch.nn.Conv2d(5, 5, 3)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        x1 = self.conv1(x)
        x1 = self.relu(x1)       # first use of the shared ReLU
        x2 = self.conv2(x1)
        output = self.relu(x2)   # second use of the same ReLU instance
        return output

block = Block()

torch.ao.quantization.fuse_modules(
    block,
    [
        ['conv1', 'relu'],
    ],
    inplace=True,
)

We will see the following:

Block(
  (conv1): ConvReLU2d(
    (0): Conv2d(3, 5, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
  )
  (conv2): Conv2d(5, 5, kernel_size=(3, 3), stride=(1, 1))
  (relu): Identity()
)

The ReLU gets fused into the first Conv, and the relu attribute then gets set to an Identity. However, the problem is that this ReLU, which has now been set to Identity, is still used later on in the forward call to transform x2. The fusing is therefore tantamount to removing the second ReLU operation, which silently changes the very nature of your model and is very much not what we want! As such, I would highly recommend:

  1. Initializing ReLUs as named modules in your model definition, and not using functional APIs (e.g. do not use torch.nn.functional.relu);
  2. Initializing each ReLU that is used separately;
  3. Giving a name to each ReLU, according to some naming convention, so that it can be easily and uniquely fused to its desired Conv, or not fused, as the case demands.

I would do the same for all fusible layers, and generally use a naming convention that makes it easy to fuse the desired modules.
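For example, a fixed version of the Block above that follows this advice might look like the sketch below (the relu1/relu2 names are just my convention, not anything PyTorch requires):

import torch

class Block(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(3, 5, 3)
        self.relu1 = torch.nn.ReLU()   # named after the Conv it will be fused with
        self.conv2 = torch.nn.Conv2d(5, 5, 3)
        self.relu2 = torch.nn.ReLU()

    def forward(self, x):
        x = self.relu1(self.conv1(x))
        x = self.relu2(self.conv2(x))
        return x

block = Block()
torch.ao.quantization.fuse_modules(
    block,
    [['conv1', 'relu1'], ['conv2', 'relu2']],  # each ReLU fuses into exactly one Conv
    inplace=True,
)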

This name-based, Eager Mode method is inherently fragile: if someone changes the name of any relevant module, the fusing will silently fail, which is quite bad. There are 2 solutions: strict enforcement of naming conventions inside the model (e.g. with unit tests ensuring that every to-fuse module appears in exactly one sublist, or even that there are no non-fused activations in one’s model, if possible), or using the graph-based method.

The graph-based method symbolically traces your model and fuses whatever modules can be fused. It is part of FX Graph / Export Mode quantization. I would highly, highly recommend using it if you can (and I will have a separate post on it in the upcoming weeks), as it is a much more robust fusing scheme than the Eager Mode, name-based method. FX Graph Mode also supports more features. E.g., in Eager Mode I have sometimes had to fuse a Conv with “advanced” activations. I managed to manually hack it, but it would have been far easier if I had been working in FX Graph Mode and could rely on the suite of built-in features, e.g. fused ConvLeakyReLU and/or Sigmoid. However, this involves working with the graph directly, and so in my future post I will explain the graph a bit more and how it works.

For fusing BatchNorms, there are 2 options. One can have either a quantized Conv-BatchNorm (or Conv-BatchNorm-ReLU) that has static BatchNorm statistics, or one that updates them during training. In inference, one should generally use static BatchNorm statistics that can be fused into the Conv with the BatchNorm set to Identity, but during Quantized Aware Training (QAT), which we will discuss later, one may wish to update the BatchNorm stats.

To have a static BatchNorm folded into the Conv, one would use torch.quantization.fuse_modules. To have fused quantized BatchNorm layers that update their statistics during training, one would use torch.quantization.fuse_modules_qat, e.g. as sketched below.
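As a hedged sketch (with hypothetical conv1/bn1/relu1 module names), the two options look like this; note that the static folding path is typically done with the model in eval mode, while the QAT fuser expects the model in train mode:

import torch
import torch.nn as nn

# Hedged sketch with hypothetical module names.
model = nn.Sequential()
model.add_module('conv1', nn.Conv2d(3, 8, 3))
model.add_module('bn1', nn.BatchNorm2d(8))
model.add_module('relu1', nn.ReLU())

# Inference / PTQ path: fold the BatchNorm statically into the Conv (model in eval mode).
model.eval()
torch.ao.quantization.fuse_modules(model, [['conv1', 'bn1', 'relu1']], inplace=True)

# QAT path: keep the BatchNorm statistics updating during training (model in train mode).
# model.train()
# torch.ao.quantization.fuse_modules_qat(model, [['conv1', 'bn1', 'relu1']], inplace=True)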

When porting models to hardware, one is not limited to fusing the modules that PyTorch allows you to fuse. For example, on Intel via OpenVINO one can fuse a Conv and a PReLU. This is not natively supported by PyTorch, but one could write one’s own custom implementation. Fused modules are fundamentally just Sequentials: on the Python level, fused modules are Sequential wrappers; on the C++ ATen level, a custom kernel performs the fused operations prior to the requantization step.

That is it for fusing! It’s an important and highly recommended (but technically optional) step when quantizing one’s models!

Step 3: Assigning QConfigs

The third step in quantizing a model is to tell each layer how it should quantize. E.g., should it do symmetric or affine quantization? Should the weights be quantized on a per-tensor or per-channel granularity, and what about the activation (i.e. output feature)? Should the qparams be updated via observation only, or via gradient descent, or should they be fixed at initialization? What kind of PTQ observer should be used, should it be a MinMax observer, or a Histogram observer that minimizes float-to-quant L2 distance, or something else? What should the initial qparams be? These instructions are all given via the QConfigs, which are assigned to each layer as attributes.

QConfigs are quite simple, but they are key parts of your quantization. Here’s an example, where we create and then assign a qconfig to an arbitrary module.

import torch
import torch.ao.quantization as tq
from torch.ao.quantization._learnable_fake_quantize import _LearnableFakeQuantize

activation_qconfig = _LearnableFakeQuantize.with_args(
    observer=tq.MinMaxObserver,       # PTQ observer
    quant_min=0,
    quant_max=255,
    dtype=torch.quint8,
    qscheme=torch.per_tensor_affine,  # we specify we want per-tensor affine quantization
    scale=0.1,                        # initial qparam scale
    zero_point=0.0,                   # initial qparam zero-point
    use_grad_scaling=True,
)

module = torch.nn.Module()

module.qconfig = tq.QConfig(
    activation=activation_qconfig,
    weight=tq.default_observer.with_args(dtype=torch.qint8),
)

The weight qconfig we assign above is just the default one provided by PyTorch. I do not recommend using it in your own models; it is here purely for the sake of the example. One should always ensure that one’s qconfigs are as close to ideal as possible, and some experimentation may be necessary.

The activation qconfig is a _LearnableFakeQuantize qconfig that can train its qparams via gradient descent. We are using a MinMax PTQ observer, with per-tensor quantization of the activation, and we set use_grad_scaling=True. Grad scaling, which is very popular, involves dividing the gradient based on the size of the tensor in question, and was introduced in Esser et al. (2019). It was designed with SGD optimizers in mind, and is common practice in QAT. An intuition as to why it works is that the qparams are very influential parameters in your model: updating them slowly, so that the quantization does not become extremely chaotic, makes a certain amount of sense. If you use SGD or some other optimizer that assigns the same learning rate to every parameter in your model, you will want grad scaling to keep these highly influential qparams from being updated too aggressively.

QConfigs should be assigned to each module that we want to quantize. If you don’t want to quantize a module, you can assign None as its qconfig, and that will keep the module in floating point. However, in that case you also need to place Quant/DeQuant stubs correctly around the non-quantized parts of your model.
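As a hedged sketch of how that assignment typically looks (the module names here are made up, and get_default_qat_qconfig is just PyTorch’s helper for a reasonable default):

import torch
import torch.nn as nn
import torch.ao.quantization as tq

# Hedged sketch: a qconfig set on a parent module is picked up by its children during
# prepare(), unless a child overrides it (e.g. with None to stay in floating point).
model = nn.Sequential()
model.add_module('features', nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU()))
model.add_module('classifier', nn.Linear(8, 10))

model.qconfig = tq.get_default_qat_qconfig('qnnpack')  # default for everything below
model.classifier.qconfig = None                        # keep the classifier in floating point

# Remember: with the classifier left in float, a DeQuantStub needs to sit before it
# in the forward pass, not after it.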

Assigning qconfigs is quite simple, but the complexity comes from choosing which qconfigs to assign, and in creating one’s own custom qconfigs and/or PTQ observers. torch.ao.quantization.observer.py contains the native PyTorch observers, and torch.ao.quantization.fake_quantize.py and torch.ao.quantization._learnable_fake_quantize.py contain the native fake-quantize objects.

Rules of thumb for selecting qconfigs:

A good rule of thumb is that you want per-channel, symmetric quantization for weights, and per-tensor, affine quantization for activations. Per-channel quantization of weights is tolerated because the weight tensors can be loaded into Multiply-And-Accumulate (MAC) operators individually for each channel, and so there isn’t significant overhead for having per-channel quantization. Symmetric quantization is highly recommended for weights, because the zero-point associated with affine quantization can cause significant slowdowns because of the need for data-dependent terms that can’t be pre-calculated (see NVIDIA’s white paper for an in-depth explanation). Weights also tend to be symmetrically distributed around 0, so symmetric quantization is generally not a significant burden for weight tensors.

Activations can typically only afford per-tensor quantization, because per-channel quantization of activations is generally not supported by hardware. This is a bit tricky to explain (and for a more detailed mathematical explanation see the NVIDIA white paper, section 3.2), but it has to do with what values, specifically what intermediate scaling factors inside the MAC working on quantized tensors, can be pre-computed ahead of time. In practice, the activation data that is fed to the MAC may belong to different batch instances (and so any pre-computed scaling factor would not be meaningful, because it would only be correct for a given batch instance), and calculating the scaling factor on-the-fly would incur significant computational overhead. tl;dr: hardware for Neural Networks (e.g. MAC operators) have been extremely optimized, and so generally do not support per-channel quantization of activations.

Affine quantization is generally desired for activations, because some activations, e.g. those outputted by a ConvReLU, are positive. It would be a shame to spend quantization range on negative values, as one does in symmetric quantization, if all of the values in the tensor are positive or zero. That is why affine quantization is generally recommended for activations. However, the only exception to this that I know of is for targeting NVIDIA GPUs: their TensorRT quantization framework requires one to do symmetric quantization of both weights and activations.
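A qconfig implementing this rule of thumb could look roughly like the sketch below. It closely mirrors PyTorch’s own QAT defaults, but treat the specific observer choices and quant ranges as assumptions to verify against your own hardware target:

import torch
import torch.ao.quantization as tq

# Sketch of the rule of thumb: per-channel symmetric weights, per-tensor affine activations.
rule_of_thumb_qconfig = tq.QConfig(
    activation=tq.FakeQuantize.with_args(
        observer=tq.MovingAverageMinMaxObserver,
        quant_min=0,
        quant_max=255,
        dtype=torch.quint8,
        qscheme=torch.per_tensor_affine,      # per-tensor, affine
    ),
    weight=tq.FakeQuantize.with_args(
        observer=tq.MovingAveragePerChannelMinMaxObserver,
        quant_min=-128,
        quant_max=127,
        dtype=torch.qint8,
        qscheme=torch.per_channel_symmetric,  # per-channel, symmetric
    ),
)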

Finally, it is generally best to use learnable qparams, unless one has a tensor where one knows exactly what the qparams should be. For example, if one is quantizing an RGB image tensor, then one should always use affine quantization with fixed qparams: a zero-point of 0 and a scale of 1/255. This ensures perfect quantization, as RGB images natively live on a {0, 1/255, 2/255, … , 255/255} grid anyway.
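For that RGB case, a hedged sketch of a fixed-qparams activation quantizer, using PyTorch’s FixedQParamsObserver / FixedQParamsFakeQuantize classes, might look like the following (double-check the exact constructor arguments against your PyTorch version):

import torch
import torch.ao.quantization as tq

# Hedged sketch: a fixed-qparams quantizer for tensors already on the {0, 1/255, ..., 1} RGB grid.
rgb_observer = tq.FixedQParamsObserver.with_args(
    scale=1.0 / 255.0,
    zero_point=0,
    dtype=torch.quint8,
    qscheme=torch.per_tensor_affine,
    quant_min=0,
    quant_max=255,
)
rgb_fake_quant = tq.FixedQParamsFakeQuantize.with_args(observer=rgb_observer)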

Understanding PTQ and QAT

Before we move on, we should take some time to understand PTQ and QAT.

Post Training Quantization (PTQ):

PTQ is an extremely common technique for deriving qparams. It can be used for both weight and activation tensors. For weights, it observes the weight tensor in a data-independent way (which makes sense, because it’s a weight tensor), and calculates the qparams so as to minimize some loss between the floating-point and quantized tensor. This loss may be to minimize clamping error on the quantized tensor, or to minimize the L2-distance between the floating-point and quantized tensor, etc. Which loss to minimize will depend on which PTQ observer you are using, and you can build as many PTQ observers as you can imagine!

For activations, it’s the same thing, except that we need calibration data. This makes sense, since activations, unlike weights, are data dependent. Again, I want to highlight that in int8 nomenclature, activations refer to the output features of a layer, not the activation function. So, to do PTQ on activations, we have to feed data through the model, and the same process as for weights will occur: the qparams will get updated with each forward pass. Some PTQ observers will only look at the latest forward pass, and some will aggregate statistics across multiple forward passes.

The reason that earlier we highlighted the importance of fusing modules for PTQ is because PTQ looks at each layer individually. In the case of activations, the PTQ observer will only see the floating-point activation, and the quantized activation. It does not see any immediate downstream transforms on the activation. Therefore, a PTQ observer for a Conv does not know that it may be followed by a ReLU activation function. In which case, the PTQ observer will probably assign negative quantization range for the output of the Conv, even though that negative quantization range will immediately go to waste when the following ReLU clamps all of the values to be at least 0. The benefit of fusing for PTQ is that the module will be seen as a whole. In our example, the ConvReLU will be PTQ’d as a unit, and will not assign negative quantization range to the activation as it will “see” the effect of the ReLU.

Quantization Aware Training (QAT):

QAT is also an extremely common technique for deriving quantization parameters. In fact, QAT does not just derive qparams, it updates all of the parameters in the model. It does so by applying the fake-quant operation to the weights and activations, and training in fake-quantized space, effectively fine-tuning the model for quantized space. As the weights are loaded, they are fake-quantized, and the activation of each layer, prior to being fed to the next layer, is also fake-quantized, i.e. clamped and rounded to the quantization grid of the layer in question.

You might wonder, how do you train a model in quantized space? Doesn’t the rounding kill the gradients, since it’s non-differentiable? QAT typically uses the Straight-Through Estimator (STE), which basically just ignores the rounding in the backwards pass. There are also other little tricks that one can use such as the popular Learned Step-Size Quantization (LSQ), which adjusts the qparam scale gradients so that dX_fq / ds increases as the distance between X_fq and a quantization transition point decreases, where X_fq is a quantized tensor and s is the scale qparam. I.e., the impact of the scale qparam on the output is larger the closer one is to a quantization bin boundary. It’s kind of like playing Marco-Polo as you are getting near the quantization transition point.
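To make the STE idea concrete, here is a minimal, purely illustrative sketch of a straight-through rounding op (PyTorch’s real fake-quantize kernels, and their STE backward passes, live in ATen):

import torch

class RoundSTE(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return torch.round(x)

    @staticmethod
    def backward(ctx, grad_output):
        # STE: pretend round() was the identity, so gradients flow through unchanged
        return grad_output

x = torch.randn(5, requires_grad=True)
RoundSTE.apply(x).sum().backward()
print(x.grad)  # all ones, as if the rounding were not there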

Step 4: Preparing the model for fake-quant

We’ve done all of the hard work; now it gets much easier.

To transform the model from the fused-and-qconfigs-attached-model into a fake-quantizable model, we merely call:

torch.quantization.prepare(model, inplace=True)

This will read the qconfigs attached to each layer, and swap out the “floating-point” layers with quantizable layers. In practice, to take the example of a Conv, this will transform the forward call from (in pseudocode):

output = Conv(input, weight, bias)

to:

output = activation_post_process(
    Conv(
        input,
        weight_fake_quant(weight),
        bias
    )
)

Where activation_post_process is the forward call of a module / quantization object that does the requantization step on the activation, and weight_fake_quant is the forward call of a module that does the quantization of the weight tensor. For example, if one used the _LearnableFakeQuantize qconfig that I used earlier for both weights and activations, then the quantization step for both weights and activations (individually) would be done by the forward call specified here (I have added lots of comments below):

def forward(self, X):
    # Does PTQ on the layer if the `static_enabled` toggle is True
    if self.static_enabled[0] == 1:  # type: ignore[index]
        # Feeds the tensor to the PTQ observer, e.g.
        # layer.activation_post_process.activation_post_process(X.detach())
        self.activation_post_process(X.detach())
        # Calculate the qparams, based on whatever metric the PTQ observer is optimizing for
        _scale, _zero_point = self.activation_post_process.calculate_qparams()
        _scale = _scale.to(self.scale.device)
        _zero_point = _zero_point.to(self.zero_point.device)
        # Overwrite qparams with those derived from the observer
        self.scale.data.copy_(_scale)
        self.zero_point.data.copy_(_zero_point)
    else:
        # Ensures during training that the scale qparam doesn't learn an unsupported value,
        # e.g. 0 or a negative value.
        self.scale.data.clamp_(min=self.eps.item())  # type: ignore[operator]

    # Fake-quantizes the tensor if the `fake_quant_enabled` toggle is True
    if self.fake_quant_enabled[0] == 1:
        # If symmetric quantization, it zeroes the zero-point
        if self.qscheme in (torch.per_channel_symmetric, torch.per_tensor_symmetric):
            self.zero_point.data.zero_()

        # Grad scaling
        if self.use_grad_scaling:
            grad_factor = 1.0 / (X.numel() * self.quant_max) ** 0.5
        else:
            grad_factor = 1.0

        # Performs the fake-quantize operation (and backward pass
        # calculation) in the ATen C++ kernel
        if self.qscheme in (
                torch.per_channel_symmetric, torch.per_channel_affine):
            X = torch._fake_quantize_learnable_per_channel_affine(
                X, self.scale, self.zero_point, self.ch_axis,
                self.quant_min, self.quant_max, grad_factor)
        else:
            X = torch._fake_quantize_learnable_per_tensor_affine(
                X, self.scale, self.zero_point,
                self.quant_min, self.quant_max, grad_factor)

    return X

There’s a lot going on here, but the key thing to note is that _LearnableFakeQuantize comes with boolean toggles (fake_quant_enabled, static_enabled) that are stored as buffers on the quantization objects, which live at module.activation_post_process for activations and module.weight_fake_quant for weights. The status of these toggles dictates whether we do PTQ on, and/or fake-quantize, the tensor in question. The quantization object is also where the qparams are stored in the model state-dict, e.g. layer.activation_post_process.scale for the activation scale qparam.

Furthermore, the PTQ observers are attached to the quantization objects. The naming convention is a bit confusing, as the PTQ observers are also called activation_post_process. This means that the PTQ observer for the activation of layer will live at layer.activation_post_process.activation_post_process, and that for the weight will live at layer.weight_fake_quant.activation_post_process. I suspect this is because it makes the code easily compatible with dynamic quantization, where activations only have a PTQ observer and no qparams that are independent of the observer. However, I do think it may be quite confusing for beginners in static quantization.
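To make the nesting concrete, here is a hedged sketch that prepares a toy model and pokes at where everything ends up (the exact classes you see printed will depend on your qconfig and PyTorch version):

import torch
import torch.nn as nn
import torch.ao.quantization as tq

# Hedged sketch: prepare a toy model and inspect the attribute nesting described above.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())
model.train()
model.qconfig = tq.get_default_qat_qconfig('qnnpack')
tq.prepare_qat(model, inplace=True)

conv = model[0]
print(type(conv.weight_fake_quant))                          # weight quantization object
print(type(conv.activation_post_process))                    # activation quantization object
print(conv.activation_post_process.scale)                    # activation scale qparam
print(conv.activation_post_process.activation_post_process)  # the activation's PTQ observer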

You can toggle the state of the quantization objects with provided methods, which vary depending on the qconfig / class of the quantization object. For example, the _LearnableFakeQuantize class provides methods such as enable_param_learning, enable_static_estimate and enable_static_observation, which set these toggles in useful combinations (PTQ observation on or off, fake-quantization on or off, gradient updates of the learned qparams on or off). It is also trivial to create one’s own toggling methods if there is something in particular you would like to achieve: one merely needs to change the value of the relevant buffer to either 0 or 1. An easy way to enable or disable fake-quant throughout the entire model is to use model.apply(torch.quantization.enable_fake_quant) or model.apply(torch.quantization.disable_fake_quant), as this will iterate through your model and turn fake-quant on or off for each layer.
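As a hedged sketch of a typical “PTQ calibration, then fake-quant” toggling flow, for standard FakeQuantize-based qconfigs (see the caveat in the next paragraph before relying on this with _LearnableFakeQuantize), with stand-in data:

import torch
import torch.nn as nn
import torch.ao.quantization as tq

# Hedged sketch: PTQ calibration followed by fake-quant, using the model.apply() toggles.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())
model.train()
model.qconfig = tq.get_default_qat_qconfig('qnnpack')
tq.prepare_qat(model, inplace=True)

calibration_data = [torch.randn(1, 3, 32, 32) for _ in range(8)]  # stand-in for real data

# 1) PTQ calibration: observers on, fake-quant off.
model.apply(tq.enable_observer)
model.apply(tq.disable_fake_quant)
with torch.no_grad():
    for batch in calibration_data:
        model(batch)

# 2) QAT or fake-quant evaluation: observers off (qparams frozen), fake-quant on.
model.apply(tq.disable_observer)
model.apply(tq.enable_fake_quant)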

I will provide a warning about the native PyTorch quantization objects: the _LearnableFakeQuantize class uses a toggle called static_enabled, instead of the observer_enabled toggle that FakeQuantize and the other fake-quantize classes use to control whether or not one does PTQ. This mismatch can cause problems if one assumes that the same toggling methods will work for _LearnableFakeQuantize and FakeQuantize. I presume the difference exists to highlight that the PTQ forward calls work slightly differently between the two, but regardless, it is a difference one needs to be aware of. I am working on a backwards-compatible PR to eliminate this issue, but in the meantime (or if it is not accepted), just be aware of it. (I will edit this post if the PR is accepted or the situation changes.)

Step 5: Converting the model to “true” int8

There’s a reason the fake-quant model is called “fake” quant: in terms of data types, it is still a floating-point model. The floating-point values are clamped and rounded onto the quantization grid, but they are still floating-point numbers; they just belong to a grid, i.e.

X_fq ∈ {q_min, q_min + scale, …, q_min + 254*scale, q_max}

for affine int8 quantization, where X_fq is the fake-quantized version of the floating-point tensor X, q_min = (0 - zero_point) * scale and q_max = q_min + 255*scale.
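You can sanity-check this grid structure directly with the built-in per-tensor fake-quantize op:

import torch

# The fake-quantized values are still float32, but they all sit on the affine grid
# {(q - zero_point) * scale : q = 0, ..., 255}, clamped to [q_min, q_max].
x = torch.randn(1000)
scale, zero_point = 0.0625, 30   # an exactly-representable scale, to keep the checks exact

x_fq = torch.fake_quantize_per_tensor_affine(x, scale, zero_point, 0, 255)

print(x_fq.dtype)                    # torch.float32: still a float tensor
print(x_fq.unique().numel() <= 256)  # True: at most 256 distinct values
q_min = (0 - zero_point) * scale
q_max = (255 - zero_point) * scale
print(bool(x_fq.min() >= q_min), bool(x_fq.max() <= q_max))  # True True: clamped to the grid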

This fake-quant space is useful for training and manipulating the model. This is partly because PyTorch is optimized for fp32 numbers, but mainly because of the need for differentiability when training the model in QAT.

So, fake quant is great for training, but we do eventually want a “true” quantized model, e.g. an int8 model, that only uses 8 bits to represent the values. In PyTorch, this is as simple as calling:

backend = 'qnnpack'
torch.backends.quantized.engine = backend
fake_quant_model.to('cpu')
converted_model = torch.quantization.convert(fake_quant_model)

You are not required to use the qnnpack backend; it’s just what I tend to use.

Conversion issues:

When converting a model, you may run into an error if you have non-leaf tensors sitting in your model state-dict. This can happen if you are using the state-dict instead of a global variable to pass intermediate feature tensors from within your model to another location. For example, this may happen if you are using those features for some special training loss, or as part of your model functionality. This will not allow the model to convert, because the internal deepcopy call will fail. I would recommend removing the non-leaf tensor from the state-dict before converting the model, and then the model should convert correctly.

However, the converted model may still throw an error once you try to feed data through it. Common issues are:

  • Not having the model on the same device as the input tensor, e.g. the converted model lives on CPU and the input tensor lives on CUDA. The solution is to transfer the input tensor to the same device first. For a converted model, both model and inputted data should generally be on the CPU (e.g. qnnpack requires this).
  • Using *, +, -, / and concat instead of FFs.
  • Not having Quant/DeQuant stubs in the correct places. E.g. you are feeding a floating-point tensor to a quantized layer before passing it through a QuantStub first, or you try to feed an already quantized tensor through a QuantStub, etc. The solution is to correctly place your Quant/DeQuant stubs, but the error messages will highlight where the problem is.
  • Not having assigned qconfigs to a layer that you expect to be quantized.

Generally speaking though, if all has gone correctly, then the converted model should give almost exactly the same performance as the fake-quant model. There may be slight discrepancies due to minor floating-point differences in the fake-quant model, or due to CUDA non-determinism if running on CUDA, but these differences should not be significant.
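A quick, hedged sanity check along those lines, re-using the fake_quant_model and converted_model variables from the snippet above, and assuming both models take a single float tensor and return a dequantized float output (adjust the input shape to whatever your model expects):

import torch

x = torch.randn(1, 3, 224, 224)  # stand-in input; use your model's expected shape

fake_quant_model.eval()
converted_model.eval()
with torch.no_grad():
    y_fake = fake_quant_model(x)
    y_int8 = converted_model(x)

max_diff = (y_fake - y_int8).abs().max().item()
print(f"max |fake-quant - converted| = {max_diff:.6f}")  # should be ~0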

Step 6: Unit Testing: it will save you heartache down the line

Unit testing may not be the sexiest topic, but it is one that I am quite passionate about. It is generally not considered a mandatory part of neural network quantization, but I really think it should be. Neural network quantization has a steep learning curve, and on top of that, things tend to fail silently, which can be a disaster for your quantization. GPUs are also not cheap, so it’s important that your experiments succeed and that you catch bugs as soon as possible. It is therefore important to make your quantization code robust, so that things do not fail silently. I’ve been doing neural network quantization for quite some time: I’ve done something like a thousand int8 experiments, and I authored the quantization codebase at my company. I have seen all sorts of bugs and things that failed silently, and here are a few examples (most of these were my doing):

  • Layers not fusing correctly, or causing certain “shared” layers to be set to Identity unexpectedly (e.g. the ReLU in the ResNet example I shared earlier).
  • Your PTQ observers being on when they shouldn’t be, therefore over-writing your qparams when you didn’t expect them to.
  • Your quantization-state being incorrect: layers may have fake-quant disabled when it should be enabled, and vice-versa, including for whatever custom quantization forward/backward calls you may experiment with. Even for PyTorch’s native quantization objects, they use the same class methods that will either do the thing you expect them to, or do nothing, depending on which class you are using (e.g. the observer_enabled vs static_enabled discrepancy between FakeQuantize and _LearnableFakeQuantize quantizers).
  • QParams not getting updated during QAT, e.g. you failed to include the qparams in the optimizer, because you initialized your optimizer prior to preparing the model for fake-quant, during which the qparams were generated and attached to your model.
  • QParam gradients getting toggled on/off when they shouldn’t be.
  • Other miscellaneous issues in how quantization interacts with your specific model, or your custom quantization techniques. Unfortunately I can’t share the most interesting examples of these, because they’re IP and/or trade secrets. Suffice to say, unit tests are very useful!

In many ways, working with neural network quantization involves opening up the black box of your neural network and trying to diagnose what is happening inside. In doing so, you will likely come across things happening that you weren’t expecting or hoping for, and unit tests can help quickly catch those issues the next time they occur. It also protects your quantization code from being used “inappropriately” by others running experiments. Quantization-specific visualization code can also be very useful, but that’s another blog for another time!
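To give one concrete flavour of such a test, here is a hedged sketch of a check on quantization state, assuming the quantizers are standard FakeQuantize-based objects (it would need adapting for _LearnableFakeQuantize, given the static_enabled vs observer_enabled discrepancy mentioned above):

import torch
import torch.ao.quantization as tq

def assert_quantizer_state(model, expect_fake_quant: bool, expect_observers: bool):
    """Fail loudly if any quantizer in the model is not in the expected state."""
    for name, module in model.named_modules():
        if isinstance(module, tq.FakeQuantizeBase):
            fq_on = bool(module.fake_quant_enabled[0])
            obs_on = bool(module.observer_enabled[0])
            assert fq_on == expect_fake_quant, f"{name}: fake_quant_enabled is {fq_on}"
            assert obs_on == expect_observers, f"{name}: observer_enabled is {obs_on}"

# Example usage on a toy prepared model:
model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU())
model.train()
model.qconfig = tq.get_default_qat_qconfig('qnnpack')
tq.prepare_qat(model, inplace=True)

model.apply(tq.disable_observer)
model.apply(tq.enable_fake_quant)
assert_quantizer_state(model, expect_fake_quant=True, expect_observers=False)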

That is it! In summary, there are 6 major steps to getting a quantized int8 model that can run on hardware more efficiently. These are:

  1. Required Architecture Changes (Quant/DeQuant stubs and FFs).
  2. Fusing modules. Optional, but highly recommended.
  3. Attaching QConfigs.
  4. Preparing the model for quantization, i.e. fake-quant.
  5. Converting the model to a “true” int8 model.
  6. Unit Testing: it will save you heartache down the line. Optional, but highly recommended.

There are many tricks to getting quantized models to perform just as well as floating-point models, but as I mentioned earlier, that is the topic of a large amount of public and private research. There are also many ecosystems other than PyTorch that you will need to familiarize yourself with, depending on the device target for your quantized neural network, e.g. OpenVINO (which I would personally recommend) or Neural Compressor for Intel, Qualcomm’s AI Engine Direct, NVIDIA’s TensorRT, and ONNX to rule them all! (For those of you who don’t know, ONNX is a very common intermediary between PyTorch / other DL libraries and hardware SDKs like OpenVINO.)

I hope this post was useful, and if you have any questions or want some help in quantizing your model, please leave a comment or get in touch!

Best regards, and many thanks,
Oscar Wiljam Savolainen, PhD

LinkedIn: https://www.linkedin.com/in/oscar-savolainen-b88277121/
My AI educational channel: https://www.youtube.com/@OscarSavolainen
