PyTorch Quantization Aware Training Model Init¶
- model_compression_toolkit.qat.pytorch_quantization_aware_training_init_experimental(in_model, representative_data_gen, target_resource_utilization=None, core_config=CoreConfig(), qat_config=QATConfig(), target_platform_capabilities=DEFAULT_PYTORCH_TPC)¶
Prepare a trained PyTorch model for quantization aware training. First, the model quantization is optimized with post-training quantization; then the model layers are wrapped with QuantizeWrappers. The model is quantized using symmetric quantization thresholds (power of two). The model is first optimized using several transformations (e.g., folding BatchNormalization into preceding layers). Then, using a given dataset, statistics (e.g., min/max, histogram) are collected for each layer’s output (and input, depending on the quantization configuration). For each possible bit-width (per layer), a threshold is then calculated using the collected statistics. Then, if a mixed-precision config is given in core_config, an ILP solver is used to find a mixed-precision configuration and set a bit-width for each layer. The model is built with fake_quant nodes for quantizing activations. Weights are kept as float and are quantized online during training by the quantization wrapper’s weight quantizer. To limit the maximal model size, a target resource utilization with weights_memory set (in bytes) needs to be passed.
- Parameters:
in_model (Model) – PyTorch model to quantize.
representative_data_gen (Callable) – Dataset used for initial calibration.
target_resource_utilization (ResourceUtilization) – ResourceUtilization object to limit the search of the mixed-precision configuration as desired.
core_config (CoreConfig) – Configuration object containing parameters of how the model should be quantized, including mixed precision parameters.
qat_config (QATConfig) – QAT configuration.
target_platform_capabilities (TargetPlatformCapabilities) – TargetPlatformCapabilities to optimize the PyTorch model according to.
- Returns:
A quantized model, along with user information that may be needed to handle the quantized model.
Examples
Import MCT:
>>> import model_compression_toolkit as mct
Import a Pytorch model:
>>> from torchvision.models import mobilenet_v2
>>> model = mobilenet_v2(pretrained=True)
Create a random dataset generator for the required number of calibration iterations (num_calibration_batches). In this example, a random dataset of 10 batches, each containing 4 images, is used:
>>> import numpy as np
>>> num_calibration_batches = 10
>>> def repr_datagen():
>>>     for _ in range(num_calibration_batches):
>>>         yield [np.random.random((4, 3, 224, 224))]
Create an MCT core config, containing the quantization configuration:
>>> config = mct.core.CoreConfig()
Pass the model, the representative dataset generator, the configuration, and (optionally) a target resource utilization to get a quantized model. The returned model contains quantizer wrappers for fine-tuning the weights:
>>> quantized_model, quantization_info = mct.qat.pytorch_quantization_aware_training_init_experimental(model, repr_datagen, core_config=config)
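To constrain the mixed-precision search, a target resource utilization can also be passed (omitted above). A minimal sketch, assuming mct.core.ResourceUtilization accepts a weights_memory limit in bytes, as described above:
>>> ru = mct.core.ResourceUtilization(weights_memory=10 * (2 ** 20))  # assumed API: cap weights memory at 10 MB
>>> quantized_model, quantization_info = mct.qat.pytorch_quantization_aware_training_init_experimental(model, repr_datagen, target_resource_utilization=ru, core_config=config)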
For more configuration options, please take a look at our API documentation.
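Because the returned model is a regular torch.nn.Module, it can be fine-tuned with an ordinary PyTorch training loop; the quantizer wrappers fake-quantize the float weights on the fly. A minimal sketch: the optimizer, loss, and train_loader below are illustrative placeholders, not part of the MCT API:
>>> import torch
>>> optimizer = torch.optim.Adam(quantized_model.parameters(), lr=1e-4)
>>> loss_fn = torch.nn.CrossEntropyLoss()
>>> for images, labels in train_loader:  # train_loader: a hypothetical torch.utils.data.DataLoader
>>>     optimizer.zero_grad()
>>>     loss = loss_fn(quantized_model(images), labels)
>>>     loss.backward()
>>>     optimizer.step()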