exporter Module

Allows exporting a quantized model in different serialization and quantization formats. For more details about the export formats and options, please refer to the project’s GitHub README file. If you have any questions or issues, please open an issue in this GitHub repository.

QuantizationFormat

class model_compression_toolkit.exporter.QuantizationFormat(value)

Specify which quantization format to use for exporting a quantized model.

FAKELY_QUANT - Weights and activations are quantized but represented using float data type.

INT8 - Weights and activations are represented using 8-bit integer data type.

MCTQ - Weights and activations are quantized using mct_quantizers custom quantizers.
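
As a quick illustration, the format is selected by passing one of these enum members as the quantization_format argument of the export functions described below. A minimal sketch, assuming MCT is installed:

import model_compression_toolkit as mct

# List the available quantization formats.
for fmt in mct.exporter.QuantizationFormat:
    print(fmt)

# One of them (e.g. mct.exporter.QuantizationFormat.MCTQ) is later passed as the
# quantization_format argument of keras_export_model / pytorch_export_model.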

KerasExportSerializationFormat

Select the serialization format for exporting a quantized Keras model.

class model_compression_toolkit.exporter.KerasExportSerializationFormat(value)

Specify which serialization format to use for exporting a quantized Keras model.

KERAS - .keras file format

TFLITE - .tflite file format
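
For example, a serialization format is typically paired with a matching quantization format before calling the exporter. A minimal sketch (the full export call is shown in the Keras Tutorial below):

import model_compression_toolkit as mct

# A .tflite file paired with INT8 quantizers; a .keras file would typically be
# paired with QuantizationFormat.MCTQ or QuantizationFormat.FAKELY_QUANT instead.
serialization_format = mct.exporter.KerasExportSerializationFormat.TFLITE
quantization_format = mct.exporter.QuantizationFormat.INT8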

keras_export_model

Exports a Keras model that was quantized via MCT.

model_compression_toolkit.exporter.keras_export_model(model, save_model_path, is_layer_exportable_fn=is_keras_layer_exportable, serialization_format=KerasExportSerializationFormat.KERAS, quantization_format=QuantizationFormat.MCTQ)

Export a Keras quantized model to a .keras or .tflite format model (according to serialization_format). The model will be saved to the path in save_model_path. Models that are exported to .keras format can use quantization_format of QuantizationFormat.MCTQ or QuantizationFormat.FAKELY_QUANT. Models that are exported to .tflite format can use quantization_format of QuantizationFormat.INT8 or QuantizationFormat.FAKELY_QUANT.

Parameters:
  • model – Model to export.

  • save_model_path – Path to save the model.

  • is_layer_exportable_fn – Callable to check whether a layer can be exported or not.

  • serialization_format – Format to export the model according to (KerasExportSerializationFormat.KERAS, by default).

  • quantization_format – Format of how quantizers are exported (MCTQ quantizers, by default).

Returns:

Custom objects dictionary needed to load the model.

Return type:

Dict[str, type]
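
The returned dictionary can then be used when loading the exported .keras model. A minimal sketch, assuming quantized_exportable_model and keras_file_path are defined as in the Keras Tutorial below, and assuming the dictionary is accepted directly as Keras custom_objects:

import keras
import model_compression_toolkit as mct

# keras_export_model returns the custom objects required to deserialize the
# mct_quantizers layers embedded in the exported model.
custom_objects = mct.exporter.keras_export_model(model=quantized_exportable_model,
                                                 save_model_path=keras_file_path)

# Reload the exported model using those custom objects.
loaded_model = keras.models.load_model(keras_file_path, custom_objects=custom_objects)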

Keras Tutorial

To export a TensorFlow model as a quantized model, it is necessary to first apply quantization to the model using MCT:

import numpy as np
from keras.applications import ResNet50
import model_compression_toolkit as mct

# Create a model
float_model = ResNet50()
# Quantize the model.
# Notice that here the representative dataset is random for demonstration only.
quantized_exportable_model, _ = mct.ptq.keras_post_training_quantization(float_model,
                                                                         representative_data_gen=lambda: [np.random.random((1, 224, 224, 3))])

Keras serialization format

The model will be exported as a TensorFlow .keras model in which weights and activations are quantized but represented using a float32 dtype. Two optional quantization formats are available: MCTQ and FAKELY_QUANT.

MCTQ

By default, mct.exporter.keras_export_model will export the quantized Keras model to a .keras model with custom quantizers from the mct_quantizers module.

import tempfile

# Path of exported model
_, keras_file_path = tempfile.mkstemp('.keras')

# Export a keras model with mctq custom quantizers.
mct.exporter.keras_export_model(model=quantized_exportable_model,
                                save_model_path=keras_file_path)

Notice that the exported model has the same size as the quantized exportable model, since the weights are still stored as float.
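
The other formats mentioned above can be selected explicitly. A minimal sketch (reusing the same quantized_exportable_model) that exports a fakely-quantized .keras model and an INT8 .tflite model:

import tempfile

import model_compression_toolkit as mct

# Fakely-quantized .keras model: quantized values represented as float32.
_, fq_keras_file_path = tempfile.mkstemp('.keras')
mct.exporter.keras_export_model(model=quantized_exportable_model,
                                save_model_path=fq_keras_file_path,
                                quantization_format=mct.exporter.QuantizationFormat.FAKELY_QUANT)

# INT8 .tflite model: weights and activations represented as 8-bit integers.
_, tflite_file_path = tempfile.mkstemp('.tflite')
mct.exporter.keras_export_model(model=quantized_exportable_model,
                                save_model_path=tflite_file_path,
                                serialization_format=mct.exporter.KerasExportSerializationFormat.TFLITE,
                                quantization_format=mct.exporter.QuantizationFormat.INT8)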

PytorchExportSerializationFormat

Select the serialization format for exporting a quantized Pytorch model.

class model_compression_toolkit.exporter.PytorchExportSerializationFormat(value)

Specify which serialization format to use for exporting a quantized Pytorch model.

TORCHSCRIPT - torchscript format

ONNX - onnx format

pytorch_export_model

Exports a PyTorch model that was quantized via MCT.

model_compression_toolkit.exporter.pytorch_export_model(model, save_model_path, repr_dataset, is_layer_exportable_fn=is_pytorch_layer_exportable, serialization_format=PytorchExportSerializationFormat.ONNX, quantization_format=QuantizationFormat.MCTQ, onnx_opset_version=DEFAULT_ONNX_OPSET_VERSION)

Export a PyTorch quantized model to a TorchScript or ONNX model (according to serialization_format). The model will be saved to the path in save_model_path. Models that are exported with PytorchExportSerializationFormat.TORCHSCRIPT are saved as TorchScript models and use QuantizationFormat.FAKELY_QUANT (weights and activations are float fakely-quantized values). Models that are exported with PytorchExportSerializationFormat.ONNX are saved as ONNX models and can use quantization_format of QuantizationFormat.MCTQ or QuantizationFormat.FAKELY_QUANT.

Parameters:
  • model – Model to export.

  • save_model_path – Path to save the model.

  • repr_dataset – Representative dataset for tracing the pytorch model (mandatory for exporting it).

  • is_layer_exportable_fn – Callable to check whether a layer can be exported or not.

  • serialization_format – Format to export the model according to (PytorchExportSerializationFormat.ONNX, by default).

  • quantization_format – Format of how quantizers are exported (MCTQ quantizers, by default).

  • onnx_opset_version – ONNX opset version to use for exported ONNX model.

Return type:

None
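
The tutorial below demonstrates ONNX export; TorchScript export follows the same pattern. A minimal sketch, assuming quantized_exportable_model and representative_data_gen are defined as in the Pytorch Tutorial below, and using the fakely-quantized format supported for TorchScript:

import tempfile

import model_compression_toolkit as mct

# Path of exported TorchScript model
_, torchscript_file_path = tempfile.mkstemp('.pt')

# Export a TorchScript model with fakely-quantized (float) weights and activations.
mct.exporter.pytorch_export_model(model=quantized_exportable_model,
                                  save_model_path=torchscript_file_path,
                                  repr_dataset=representative_data_gen,
                                  serialization_format=mct.exporter.PytorchExportSerializationFormat.TORCHSCRIPT,
                                  quantization_format=mct.exporter.QuantizationFormat.FAKELY_QUANT)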

Pytorch Tutorial

In order to export your quantized model to ONNX format and use it for inference, some additional packages are needed. Note that they are required only for models exported to ONNX, so this step can be skipped if ONNX export is not planned:

! pip install -q onnx onnxruntime onnxruntime-extensions

Now, let’s start the export demonstration by quantizing the model using MCT:

import model_compression_toolkit as mct
import numpy as np
import torch
from torchvision.models.mobilenetv2 import mobilenet_v2

# Create a model
float_model = mobilenet_v2()


# Notice that here the representative dataset is random for demonstration only.
def representative_data_gen():
    yield [np.random.random((1, 3, 224, 224))]


quantized_exportable_model, _ = mct.ptq.pytorch_post_training_quantization(float_model, representative_data_gen=representative_data_gen)

ONNX

The model will be exported in ONNX format, where weights and activations are represented as float. Notice that onnx must be installed in order to export the model to ONNX.

There are two quantization formats to choose from: MCTQ or FAKELY_QUANT.

MCTQ Quantization Format

By default, mct.exporter.pytorch_export_model will export the quantized PyTorch model to an ONNX model with custom quantizers from the mct_quantizers module.

# Path of exported model
onnx_file_path = 'model_format_onnx_mctq.onnx'

# Export ONNX model with mctq quantizers.
mct.exporter.pytorch_export_model(model=quantized_exportable_model,
                                  save_model_path=onnx_file_path,
                                  repr_dataset=representative_data_gen)

Notice that the exported model has the same size as the quantized exportable model, since the weights are still stored as float.
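
FAKELY_QUANT Quantization Format

To export a fakely-quantized ONNX model instead, set quantization_format explicitly. A minimal sketch using the same model and representative dataset (the file name is illustrative):

# Path of exported model
onnx_fq_file_path = 'model_format_onnx_fakely_quant.onnx'

# Export ONNX model with fakely-quantized (float) weights and activations.
mct.exporter.pytorch_export_model(model=quantized_exportable_model,
                                  save_model_path=onnx_fq_file_path,
                                  repr_dataset=representative_data_gen,
                                  quantization_format=mct.exporter.QuantizationFormat.FAKELY_QUANT)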

ONNX opset version

By default, the used ONNX opset version is 15, but this can be changed using onnx_opset_version:

# Export ONNX model with mctq quantizers.
mct.exporter.pytorch_export_model(model=quantized_exportable_model,
                                  save_model_path=onnx_file_path,
                                  repr_dataset=representative_data_gen,
                                  onnx_opset_version=16)

Use exported model for inference

To load and run inference with the model that was exported to an ONNX file in MCTQ format, we will use the mct_quantizers method get_ort_session_options during onnxruntime session creation. Notice that inference on models exported in this format is slower and suffers from longer latency. However, inference of these models on IMX500 will not suffer from this issue.

import mct_quantizers as mctq
import onnxruntime as ort

sess = ort.InferenceSession(onnx_file_path,
                            mctq.get_ort_session_options(),
                            providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])

_input_data = next(representative_data_gen())[0].astype(np.float32)
_model_output_name = sess.get_outputs()[0].name
_model_input_name = sess.get_inputs()[0].name

# Run inference
predictions = sess.run([_model_output_name], {_model_input_name: _input_data})