PyTorch Structured Pruning

model_compression_toolkit.pruning.pytorch_pruning_experimental(model, target_resource_utilization, representative_data_gen, pruning_config=PruningConfig(), target_platform_capabilities=DEFAULT_PYTORCH_TPC)

Perform structured pruning on a PyTorch model to meet a specified target resource utilization. The function prunes the provided model by grouping and removing channels according to each layer's SIMD configuration in the Target Platform Capabilities (TPC). By default, the importance of each channel group is computed with the Label-Free Hessian (LFH) method, which estimates each channel's sensitivity via the Hessian of the loss function. Pruning whole channel groups, rather than individual channels, yields a more hardware-friendly architecture. The model is analyzed with a representative dataset to identify channel groups that can be removed with minimal impact on performance.

Note that the pruned model must be retrained to recover the compressed model's performance.
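To illustrate the SIMD-grouped channel scoring described above, the sketch below groups a convolution's output channels into SIMD-sized blocks and scores each block, so that whole groups are kept or pruned together. It uses a squared L2 norm purely as a stand-in importance measure; MCT's actual default scoring is the Label-Free Hessian method, and the function name here is hypothetical:

```python
import torch
import torch.nn as nn

def group_channel_scores(conv: nn.Conv2d, simd: int) -> torch.Tensor:
    """Score SIMD-sized groups of output channels of `conv`.

    The squared L2 norm of each channel's weights is only a simplified
    proxy here; MCT's default uses Label-Free Hessian (LFH) scores.
    """
    # Per-output-channel squared norm of the kernel weights.
    per_channel = conv.weight.detach().flatten(1).pow(2).sum(dim=1)
    # Pad so the channel count divides evenly into SIMD-sized groups.
    pad = (-per_channel.numel()) % simd
    padded = torch.cat([per_channel, per_channel.new_zeros(pad)])
    # Sum scores within each SIMD group; pruning acts on whole groups.
    return padded.view(-1, simd).sum(dim=1)

conv = nn.Conv2d(3, 16, kernel_size=3)
scores = group_channel_scores(conv, simd=4)  # 16 channels -> 4 group scores
```

Groups with the lowest scores are the candidates for removal, which keeps the surviving channel count a multiple of the hardware's SIMD width.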

Parameters:
  • model (Module) – The PyTorch model to be pruned.

  • target_resource_utilization (ResourceUtilization) – The resource utilization targets (e.g., weights memory) that the pruned model should meet.

  • representative_data_gen (Callable) – A function to generate representative data for pruning analysis.

  • pruning_config (PruningConfig) – Configuration settings for the pruning process. Defaults to standard config.

  • target_platform_capabilities (TargetPlatformCapabilities) – Platform-specific constraints and capabilities. Defaults to DEFAULT_PYTORCH_TPC.

Returns:

A tuple containing the pruned PyTorch model and associated pruning information.

Return type:

Tuple[Module, PruningInfo]

Note

The pruned model should be fine-tuned or retrained to recover or improve its performance post-pruning.

Examples

Import MCT:

>>> import model_compression_toolkit as mct

Import a Pytorch model:

>>> from torchvision.models import resnet50, ResNet50_Weights
>>> model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)

Create a random dataset generator:

>>> import numpy as np
>>> def repr_datagen(): yield [np.random.random((1, 3, 224, 224))]

Define a target resource utilization for pruning. Here, we aim to reduce the memory footprint of weights by 50%, assuming the model weights are represented in float32 data type (thus, each parameter is represented using 4 bytes):

>>> dense_nparams = sum(p.numel() for p in model.state_dict().values())
>>> target_resource_utilization = mct.core.ResourceUtilization(weights_memory=dense_nparams * 4 * 0.5)

Optionally, define a pruning configuration. num_score_approximations can be passed to configure the number of importance scores that will be calculated for each channel. A higher value for this parameter yields more precise score approximations but also extends the duration of the pruning process:

>>> pruning_config = mct.pruning.PruningConfig(num_score_approximations=1)

Perform pruning:

>>> pruned_model, pruning_info = mct.pruning.pytorch_pruning_experimental(
...     model=model,
...     target_resource_utilization=target_resource_utilization,
...     representative_data_gen=repr_datagen,
...     pruning_config=pruning_config)
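As noted above, the pruned model should be fine-tuned to recover accuracy. The sketch below shows a minimal PyTorch training loop of the usual shape; it uses a small stand-in model and random tensors so it is self-contained, whereas in practice you would iterate over your original training dataset with `pruned_model`:

```python
import torch
import torch.nn as nn

# Stand-in model and random data so the sketch is self-contained; in
# practice, fine-tune `pruned_model` on the original training dataset.
model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.Flatten(),
                      nn.Linear(8 * 30 * 30, 10))
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)
loss_fn = nn.CrossEntropyLoss()

model.train()
for _ in range(3):  # a few illustrative steps; real fine-tuning runs for epochs
    x = torch.randn(4, 3, 32, 32)   # batch of inputs
    y = torch.randint(0, 10, (4,))  # integer class labels
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    optimizer.step()
```

Hyperparameters (optimizer, learning rate, number of steps) are placeholders; fine-tuning a pruned network typically reuses the original training recipe with a reduced learning rate.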