PyTorch Gradient-Based Post-Training Quantization

model_compression_toolkit.gptq.pytorch_gradient_post_training_quantization(model, representative_data_gen, target_resource_utilization=None, core_config=CoreConfig(), gptq_config=None, gptq_representative_data_gen=None, target_platform_capabilities=DEFAULT_PYTORCH_TPC)

Quantize a trained PyTorch module using post-training quantization. By default, the module is quantized using symmetric quantization with power-of-two thresholds, as defined in the default TargetPlatformCapabilities. The module is first optimized with several transformations (e.g., folding BatchNormalization into preceding layers). Then, using a given dataset, statistics (e.g., min/max, histograms) are collected for each layer's output (and input, depending on the quantization configuration). Thresholds are computed from the collected statistics, and the module is quantized (both weights and activations, by default). If gptq_config is passed, the quantized weights are then further optimized with gradient-based post-training quantization: intermediate points of the float and quantized modules are compared, and the observed loss is minimized.
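To illustrate the symmetric power-of-two thresholding described above, here is a minimal NumPy sketch. This is an illustration only, not MCT's implementation; the function names `pot_threshold` and `symmetric_quantize` are hypothetical:

```python
import numpy as np

def pot_threshold(x):
    # Smallest power-of-two threshold covering the tensor's dynamic range.
    return 2.0 ** np.ceil(np.log2(np.max(np.abs(x)) + 1e-12))

def symmetric_quantize(x, n_bits=8):
    # Symmetric signed quantization on a power-of-two threshold grid,
    # returning the dequantized ("fake-quant") values.
    t = pot_threshold(x)
    levels = 2 ** (n_bits - 1)
    scale = t / levels
    q = np.clip(np.round(x / scale), -levels, levels - 1)
    return q * scale

x = np.array([-1.3, 0.2, 0.9])
xq = symmetric_quantize(x)  # values snapped to the 8-bit grid with threshold 2.0
```

GPTQ then treats the quantized weights as trainable points on this grid and fine-tunes them to minimize the discrepancy between float and quantized intermediate outputs.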

Parameters:
  • model (Module) – PyTorch module to quantize.

  • representative_data_gen (Callable) – Dataset used for calibration.

  • target_resource_utilization (ResourceUtilization) – ResourceUtilization object to limit the search of the mixed-precision configuration as desired.

  • core_config (CoreConfig) – Configuration object containing parameters of how the model should be quantized, including mixed precision parameters.

  • gptq_config (GradientPTQConfig) – Configuration for GPTQ (e.g., the optimizer to use).

  • gptq_representative_data_gen (Callable) – Dataset used for GPTQ training. If None, defaults to representative_data_gen.

  • target_platform_capabilities (TargetPlatformCapabilities) – TargetPlatformCapabilities to optimize the PyTorch model according to.

Returns:

A quantized module, and information the user may need in order to handle the quantized module.

Examples

Import Model Compression Toolkit:

>>> import model_compression_toolkit as mct

Import a PyTorch module:

>>> from torchvision import models
>>> module = models.mobilenet_v2()

Create a random dataset generator for the required number of calibration iterations (num_calibration_batches). In this example, a random dataset of 10 batches, each containing 4 images, is used:

>>> import numpy as np
>>> num_calibration_batches = 10
>>> def repr_datagen():
...     for _ in range(num_calibration_batches):
...         yield [np.random.random((4, 3, 224, 224))]
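As a quick sanity check (plain NumPy, independent of MCT), the generator above yields num_calibration_batches batches, each a single-element list holding one (4, 3, 224, 224) array:

```python
import numpy as np

num_calibration_batches = 10

def repr_datagen():
    for _ in range(num_calibration_batches):
        yield [np.random.random((4, 3, 224, 224))]

batches = list(repr_datagen())
assert len(batches) == num_calibration_batches
assert batches[0][0].shape == (4, 3, 224, 224)
```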

Create a default MCT core configuration:

>>> config = mct.core.CoreConfig()

Create a GPTQ configuration, here using the helper that builds a default configuration for a given number of fine-tuning epochs:

>>> gptq_conf = mct.gptq.get_pytorch_gptq_config(n_epochs=1)

Pass the module, the representative dataset generator, and the configurations to get a quantized module:

>>> quantized_module, quantization_info = mct.gptq.pytorch_gradient_post_training_quantization(module, repr_datagen, core_config=config, gptq_config=gptq_conf)