PyTorch Quantization Aware Training Model Init

model_compression_toolkit.qat.pytorch_quantization_aware_training_init_experimental(in_model, representative_data_gen, target_resource_utilization=None, core_config=CoreConfig(), qat_config=QATConfig(), target_platform_capabilities=DEFAULT_PYTORCH_TPC)

Prepare a trained PyTorch model for quantization aware training. First, the model quantization is optimized with post-training quantization; then the model layers are wrapped with QuantizeWrappers. The model is quantized using symmetric quantization with power-of-two thresholds. The model is first optimized with several transformations (e.g. folding BatchNormalization into preceding layers). Then, using a given dataset, statistics (e.g. min/max, histogram) are collected for each layer's output (and input, depending on the quantization configuration). For each possible bit-width (per layer), a threshold is then calculated from the collected statistics. Then, if a mixed-precision configuration is given in core_config, an ILP solver is used to find a mixed-precision configuration that sets a bit-width for each layer. The model is built with fake_quant nodes for quantizing activations. Weights are kept as float and are quantized online during training by the quantization wrapper's weight quantizer. To limit the maximal model size, a target resource utilization with weights_memory set (in bytes) needs to be passed.
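For instance, to cap the size of the quantized weights, a resource target can be constructed and passed to the init function as target_resource_utilization (a minimal sketch; the byte budget is illustrative and assumes the mct.core.ResourceUtilization API of recent MCT releases):

>>> import model_compression_toolkit as mct
>>> target_ru = mct.core.ResourceUtilization(weights_memory=3_000_000)  # ~3 MB weight budget (illustrative)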

Parameters:
  • in_model (Model) – PyTorch model to quantize.

  • representative_data_gen (Callable) – Callable that yields batches of representative data, used for initial calibration.

  • target_resource_utilization (ResourceUtilization) – ResourceUtilization object to limit the search of the mixed-precision configuration as desired.

  • core_config (CoreConfig) – Configuration object containing parameters of how the model should be quantized, including mixed precision parameters.

  • qat_config (QATConfig) – QAT configuration (see the sketch after this list).

  • target_platform_capabilities (TargetPlatformCapabilities) – TargetPlatformCapabilities to optimize the Pytorch model according to.
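A minimal sketch of constructing the optional qat_config and target_platform_capabilities arguments, assuming the default QATConfig constructor and the mct.get_target_platform_capabilities helper available in recent MCT releases:

>>> import model_compression_toolkit as mct
>>> qat_config = mct.qat.QATConfig()  # default QAT settings
>>> tpc = mct.get_target_platform_capabilities("pytorch", "default")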

Returns:

A quantized model and user information that may be needed to handle the quantized model.

Examples

Import MCT:

>>> import model_compression_toolkit as mct

Import a PyTorch model:

>>> from torchvision.models import mobilenet_v2
>>> model = mobilenet_v2(pretrained=True)

Create a random dataset generator that yields the required number of calibration batches (num_calibration_batches). In this example, a random dataset of 10 batches, each containing 4 images, is used:

>>> import numpy as np
>>> num_calibration_batches = 10
>>> def repr_datagen():
>>>     for _ in range(num_calibration_batches):
>>>         yield [np.random.random((4, 3, 224, 224))]

Create a MCT core config, containing the quantization configuration:

>>> config = mct.core.CoreConfig()
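If the mixed-precision search described above is desired, the core config might instead carry a mixed-precision configuration (a sketch; num_of_images is an illustrative value, assuming the mct.core.MixedPrecisionQuantizationConfig API of recent MCT releases):

>>> mp_config = mct.core.MixedPrecisionQuantizationConfig(num_of_images=32)  # images used for sensitivity evaluation
>>> config = mct.core.CoreConfig(mixed_precision_config=mp_config)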

Pass the model, the representative dataset generator and the configuration (and, optionally, a target resource utilization) to get a quantized model. The returned model contains quantizer wrappers for fine-tuning the weights:

>>> quantized_model, quantization_info = mct.qat.pytorch_quantization_aware_training_init_experimental(model, repr_datagen, core_config=config)
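The returned model can now be fine-tuned with a standard PyTorch training loop (a minimal sketch; train_loader and the loss/optimizer choices are hypothetical) and then finalized with the matching finalize function:

>>> import torch
>>> optimizer = torch.optim.Adam(quantized_model.parameters(), lr=1e-4)
>>> loss_fn = torch.nn.CrossEntropyLoss()
>>> for images, labels in train_loader:  # train_loader is a hypothetical DataLoader
>>>     optimizer.zero_grad()
>>>     loss = loss_fn(quantized_model(images), labels)
>>>     loss.backward()
>>>     optimizer.step()

>>> final_model = mct.qat.pytorch_quantization_aware_training_finalize_experimental(quantized_model)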

For more configuration options, please take a look at our API documentation.