Pytorch Post Training Quantization
- model_compression_toolkit.ptq.pytorch_post_training_quantization(in_module, representative_data_gen, target_resource_utilization=None, core_config=CoreConfig(), target_platform_capabilities=DEFAULT_PYTORCH_TPC)
Quantize a trained Pytorch module using post-training quantization. By default, the module is quantized with symmetric quantization using power-of-two thresholds, as defined in the default TargetPlatformCapabilities. The module is first optimized using several transformations (e.g., folding BatchNormalization into preceding layers). Then, using the given representative dataset, statistics (e.g., min/max, histogram) are collected for each layer’s output (and input, depending on the quantization configuration). Finally, thresholds are calculated from the collected statistics, and the module is quantized (both weights and activations by default).
- Parameters:
in_module (Module) – Pytorch module to quantize.
representative_data_gen (Callable) – Dataset generator used for calibration.
target_resource_utilization (ResourceUtilization) – ResourceUtilization object that constrains the search for a mixed-precision configuration (see the mixed-precision sketch at the end of the Examples).
core_config (CoreConfig) – Configuration object containing parameters of how the model should be quantized, including mixed precision parameters.
target_platform_capabilities (TargetPlatformCapabilities) – TargetPlatformCapabilities describing the platform for which the PyTorch model is optimized.
- Returns:
A quantized module, along with information the user may need to handle it.
Examples
Import a Pytorch module:
>>> from torchvision import models
>>> module = models.mobilenet_v2()
Create a representative dataset generator that yields the required number of calibration batches (num_calibration_batches). In this example, a random dataset of 10 batches, each containing 4 images, is used:
>>> import numpy as np
>>> num_calibration_batches = 10
>>> def repr_datagen():
>>>     for _ in range(num_calibration_batches):
>>>         yield [np.random.random((4, 3, 224, 224))]
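As a quick sanity check (not part of the original example), the generator can be consumed directly to confirm the batch layout, a list with one array per model input:
>>> batch = next(repr_datagen())
>>> batch[0].shape
(4, 3, 224, 224)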
Import MCT and pass the module with the representative dataset generator to get a quantized module:
>>> import model_compression_toolkit as mct
>>> quantized_module, quantization_info = mct.ptq.pytorch_post_training_quantization(module, repr_datagen)
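The returned quantized module behaves like a regular torch.nn.Module and can be called directly. A short usage sketch, assuming the module and the input tensor reside on the same device (MCT may place the returned module on GPU when one is available):
>>> import torch
>>> dummy_input = torch.from_numpy(np.random.random((4, 3, 224, 224)).astype(np.float32))  # match the calibration shape
>>> outputs = quantized_module(dummy_input)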
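To run a mixed-precision search, a resource-utilization target must be passed alongside a mixed-precision configuration. The following is a minimal sketch, assuming the mct.core.MixedPrecisionQuantizationConfig and mct.core.ResourceUtilization constructors and the weights_memory argument; exact names and defaults may differ between MCT versions, and the memory budget below is purely illustrative:
>>> config = mct.core.CoreConfig(mixed_precision_config=mct.core.MixedPrecisionQuantizationConfig())
>>> ru = mct.core.ResourceUtilization(weights_memory=10 ** 6)  # illustrative weights-memory budget in bytes
>>> quantized_module, quantization_info = mct.ptq.pytorch_post_training_quantization(module, repr_datagen, target_resource_utilization=ru, core_config=config)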