sony_custom_layers.keras

# -----------------------------------------------------------------------------
# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------

from sony_custom_layers.util.import_util import validate_installed_libraries
from sony_custom_layers import required_libraries

validate_installed_libraries(required_libraries['tf'])

from .object_detection import FasterRCNNBoxDecode, SSDPostProcess, ScoreConverter    # noqa: E402
from .custom_objects import custom_layers_scope    # noqa: E402

__all__ = ['FasterRCNNBoxDecode', 'ScoreConverter', 'SSDPostProcess', 'custom_layers_scope']
@register_layer
class FasterRCNNBoxDecode(sony_custom_layers.keras.base_custom_layer.CustomLayer):
@register_layer
class FasterRCNNBoxDecode(CustomLayer):
    """
    Box decoding as per Faster R-CNN <https://arxiv.org/abs/1506.01497>.

    Args:
        anchors: Anchors with a shape of (n_boxes, 4) in corner coordinates (y_min, x_min, y_max, x_max).
        scale_factors: Scaling factors in the format (y, x, height, width).
        clip_window: Clipping window in the format (y_min, x_min, y_max, x_max).

    Inputs:
        **rel_codes** (Tensor): Relative codes (encoded offsets) with a shape of (batch, n_boxes, 4) in centroid
                                coordinates (y_center, x_center, h, w).

    Returns:
        Decoded boxes with a shape of (batch, n_boxes, 4) in corner coordinates (y_min, x_min, y_max, x_max).

    Raises:
        ValueError: If provided with invalid arguments or an input tensor with an unexpected shape.

    Example:
        ```
        from sony_custom_layers.keras import FasterRCNNBoxDecode

        box_decode = FasterRCNNBoxDecode(anchors,
                                         scale_factors=(10, 10, 5, 5),
                                         clip_window=(0, 0, 1, 1))
        decoded_boxes = box_decode(rel_codes)
        ```
    """

    def __init__(self, anchors: Union[np.ndarray, tf.Tensor, List[List[float]]],
                 scale_factors: Sequence[Union[float, int]], clip_window: Sequence[Union[float, int]], **kwargs):
        super().__init__(**kwargs)
        anchors = tf.constant(anchors)
        if not (len(anchors.shape) == 2 and anchors.shape[-1] == 4):
            raise ValueError(f'Invalid anchors shape {anchors.shape}. Expected shape (n_boxes, 4).')
        self.anchors = anchors

        if len(scale_factors) != 4:
            raise ValueError(f'Invalid scale factors {scale_factors}. Expected 4 values for (y, x, height, width).')
        self.scale_factors = tf.constant(scale_factors, dtype=tf.float32)

        if len(clip_window) != 4:
            raise ValueError(f'Invalid clip window {clip_window}. Expected 4 values for (y_min, x_min, y_max, x_max).')
        self.clip_window = clip_window

    def call(self, rel_codes: tf.Tensor, *args, **kwargs) -> tf.Tensor:
        """ """
        if len(rel_codes.shape) != 3 or rel_codes.shape[-1] != 4:
            raise ValueError(f'Invalid input tensor shape {rel_codes.shape}. Expected shape (batch, n_boxes, 4).')
        if rel_codes.shape[-2] != self.anchors.shape[-2]:
            raise ValueError(f'Mismatch in the number of boxes between input tensor ({rel_codes.shape[-2]}) '
                             f'and anchors ({self.anchors.shape[-2]})')

        scaled_codes = rel_codes / self.scale_factors

        a_y_min, a_x_min, a_y_max, a_x_max = tf.unstack(self.anchors, axis=-1)
        a_y_center, a_x_center, a_h, a_w = corners_to_centroids(a_y_min, a_x_min, a_y_max, a_x_max)

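        # Faster R-CNN inverse box encoding: center offsets are scaled by the anchor size,
        # and height/width are decoded from log-space relative to the anchor dimensions.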
        box_y_center = scaled_codes[..., 0] * a_h + a_y_center
        box_x_center = scaled_codes[..., 1] * a_w + a_x_center
        box_h = tf.exp(scaled_codes[..., 2]) * a_h
        box_w = tf.exp(scaled_codes[..., 3]) * a_w
        box_y_min, box_x_min, box_y_max, box_x_max = centroids_to_corners(box_y_center, box_x_center, box_h, box_w)
        boxes = tf.stack([box_y_min, box_x_min, box_y_max, box_x_max], axis=-1)

        y_low, x_low, y_high, x_high = self.clip_window
        boxes = tf.clip_by_value(boxes, [y_low, x_low, y_low, x_low], [y_high, x_high, y_high, x_high])
        return boxes

    def get_config(self) -> dict:
        """ """
        config = super().get_config()
        config.update({
            'anchors': self.anchors.numpy().tolist(),
            'scale_factors': self.scale_factors.numpy().tolist(),
            'clip_window': self.clip_window,
        })
        return config

Box decoding as per Faster R-CNN https://arxiv.org/abs/1506.01497.

Arguments:
  • anchors: Anchors with a shape of (n_boxes, 4) in corner coordinates (y_min, x_min, y_max, x_max).
  • scale_factors: Scaling factors in the format (y, x, height, width).
  • clip_window: Clipping window in the format (y_min, x_min, y_max, x_max).
Inputs:

rel_codes (Tensor): Relative codes (encoded offsets) with a shape of (batch, n_boxes, 4) in centroid coordinates (y_center, x_center, h, w).

Returns:

Decoded boxes with a shape of (batch, n_boxes, 4) in corner coordinates (y_min, x_min, y_max, x_max).

Raises:
  • ValueError: If provided with invalid arguments or an input tensor with an unexpected shape.
Example:
from sony_custom_layers.keras import FasterRCNNBoxDecode

box_decode = FasterRCNNBoxDecode(anchors,
                                 scale_factors=(10, 10, 5, 5),
                                 clip_window=(0, 0, 1, 1))
decoded_boxes = box_decode(rel_codes)
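
A minimal runnable sketch (the anchor values and offsets below are made up for illustration; only the shapes matter):

import numpy as np
import tensorflow as tf
from sony_custom_layers.keras import FasterRCNNBoxDecode

# Two hypothetical anchors in normalized corner coordinates (y_min, x_min, y_max, x_max).
anchors = np.array([[0.0, 0.0, 0.5, 0.5],
                    [0.25, 0.25, 1.0, 1.0]], dtype=np.float32)
box_decode = FasterRCNNBoxDecode(anchors,
                                 scale_factors=(10, 10, 5, 5),
                                 clip_window=(0, 0, 1, 1))

# Encoded offsets with shape (batch=1, n_boxes=2, 4); all-zero codes decode back to the anchors.
rel_codes = tf.zeros((1, 2, 4), dtype=tf.float32)
decoded = box_decode(rel_codes)    # shape (1, 2, 4), corner coordinates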
class ScoreConverter(builtins.str, enum.Enum):
class ScoreConverter(str, Enum):
    # values are compatible with keras activation interface
    LINEAR = 'linear'
    SIGMOID = 'sigmoid'
    SOFTMAX = 'softmax'

Enumeration of the score conversions that can be applied to the input logits. Its values are compatible with the Keras activation interface.
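
Because ScoreConverter subclasses str, its members compare equal to their string values, so they can be passed wherever a Keras activation name is accepted (a small sketch of that behavior):

from sony_custom_layers.keras import ScoreConverter

assert ScoreConverter.SIGMOID == 'sigmoid'    # str-Enum members equal their string values
# e.g. tf.keras.layers.Activation(ScoreConverter.SIGMOID) behaves like Activation('sigmoid')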

@register_layer
class SSDPostProcess(sony_custom_layers.keras.base_custom_layer.CustomLayer):
@register_layer
class SSDPostProcess(CustomLayer):
    """
    SSD Post Processing, based on <https://arxiv.org/abs/1512.02325>.

    Args:
        anchors (Tensor | np.ndarray): Anchors with a shape of (n_boxes, 4) in corner coordinates
                                       (y_min, x_min, y_max, x_max).
        scale_factors (list | tuple): Box decoding scaling factors in the format (y, x, height, width).
        clip_size (list | tuple): Clipping size in the format (height, width). The decoded boxes are clipped to the
                                  range y=[0, height] and x=[0, width]. Typically, the clipping size is (1, 1) for
                                  normalized boxes and the image size for boxes in pixel coordinates.
        score_converter (ScoreConverter): Conversion to apply to the input logits (sigmoid, softmax, or linear).
        score_threshold (float): Score threshold for non-maximum suppression.
        iou_threshold (float): Intersection over union threshold for non-maximum suppression.
        max_detections (int): The number of detections to return.
        remove_background (bool): If True, the first class is removed from the input scores (after the score_converter
                                  is applied).

    Inputs:
        A list or tuple of:
        - **rel_codes** (Tensor): Relative codes (encoded offsets) with a shape of (batch, n_boxes, 4) in centroid
                            coordinates (y_center, x_center, h, w).
        - **scores** (Tensor): Scores or logits with a shape of (batch, n_boxes, n_labels).

    Returns:
        'CombinedNonMaxSuppression' named tuple:
        - nmsed_boxes: Selected boxes sorted by scores in descending order, with a shape of
                         (batch, max_detections, 4), in corner coordinates (y_min, x_min, y_max, x_max).
        - nmsed_scores: Scores corresponding to the selected boxes, with a shape of (batch, max_detections).
        - nmsed_classes: Labels corresponding to the selected boxes, with a shape of (batch, max_detections).
                           Each label corresponds to the class index of the selected score in the input scores.
        - valid_detections: The number of valid detections out of max_detections.

    Raises:
        ValueError: If provided with invalid arguments or input tensors with unexpected or non-matching shapes.

    Example:
        ```
        from sony_custom_layers.keras import SSDPostProcess, ScoreConverter

        post_process = SSDPostProcess(anchors=anchors,
                                      scale_factors=(10, 10, 5, 5),
                                      clip_size=(320, 320),
                                      score_converter=ScoreConverter.SIGMOID,
                                      score_threshold=0.01,
                                      iou_threshold=0.6,
                                      max_detections=200,
                                      remove_background=True)
        res = post_process([rel_codes, logits])
        boxes = res.nmsed_boxes
        ```
    """

    def __init__(self,
                 anchors: Union[np.ndarray, tf.Tensor, List[List[float]]],
                 scale_factors: Sequence[Union[int, float]],
                 clip_size: Sequence[Union[int, float]],
                 score_converter: Union[ScoreConverter, str],
                 score_threshold: float,
                 iou_threshold: float,
                 max_detections: int,
                 remove_background: bool = False,
                 **kwargs):
        """ """
        super().__init__(**kwargs)

        if not 0 <= score_threshold <= 1:
            raise ValueError(f'Invalid score_threshold {score_threshold} not in range [0, 1]')
        if not 0 <= iou_threshold <= 1:
            raise ValueError(f'Invalid iou_threshold {iou_threshold} not in range [0, 1]')
        if max_detections <= 0:
            raise ValueError(f'Invalid non-positive max_detections {max_detections}')

        self.cfg = SSDPostProcessCfg(anchors=anchors,
                                     scale_factors=scale_factors,
                                     clip_size=clip_size,
                                     score_converter=score_converter,
                                     score_threshold=score_threshold,
                                     iou_threshold=iou_threshold,
                                     max_detections=max_detections,
                                     remove_background=remove_background)
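        # Decoded boxes are clipped to the window (0, 0, height, width) implied by clip_size.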
        self._box_decode = FasterRCNNBoxDecode(anchors, scale_factors, (0, 0, *clip_size))

    def call(self, inputs: Sequence[tf.Tensor], *args, **kwargs) -> Tuple[tf.Tensor]:
        """ """
        rel_codes, scores = inputs
        if len(rel_codes.shape) != 3 or rel_codes.shape[-1] != 4:
            raise ValueError(f'Invalid input offsets shape {rel_codes.shape}. '
                             f'Expected shape (batch, n_boxes, 4).')
        if len(scores.shape) != 3:
            raise ValueError(f'Invalid input scores shape {scores.shape}. '
                             f'Expected shape (batch, n_boxes, n_labels).')
        if rel_codes.shape[-2] != scores.shape[-2]:
            raise ValueError(f'Mismatch in the number of boxes between input codes ({rel_codes.shape[-2]}) '
                             f'and input scores ({scores.shape[-2]}).')

        if self.cfg.score_converter != ScoreConverter.LINEAR:
            scores = tf.keras.layers.Activation(self.cfg.score_converter)(scores)

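        # Optionally drop the background class (label index 0) so NMS only ranks foreground labels.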
        if self.cfg.remove_background:
            scores = tf.slice(scores, begin=[0, 0, 1], size=[-1, -1, -1])

        decoded_boxes = self._box_decode(rel_codes)
        # when decoded_boxes.shape[-2]==1, nms uses same boxes for all classes
        decoded_boxes = tf.expand_dims(decoded_boxes, axis=-2)

        outputs = tf.image.combined_non_max_suppression(decoded_boxes,
                                                        scores,
                                                        max_output_size_per_class=self.cfg.max_detections,
                                                        max_total_size=self.cfg.max_detections,
                                                        iou_threshold=self.cfg.iou_threshold,
                                                        score_threshold=self.cfg.score_threshold,
                                                        pad_per_class=False,
                                                        clip_boxes=False)
        return outputs

    def get_config(self) -> dict:
        """ """
        config = super().get_config()
        d = self.cfg.as_dict()
        d['anchors'] = tf.constant(d['anchors']).numpy().tolist()
        config.update(d)
        return config

SSD Post Processing, based on https://arxiv.org/abs/1512.02325.

Arguments:
  • anchors (Tensor | np.ndarray): Anchors with a shape of (n_boxes, 4) in corner coordinates (y_min, x_min, y_max, x_max).
  • scale_factors (list | tuple): Box decoding scaling factors in the format (y, x, height, width).
  • clip_size (list | tuple): Clipping size in the format (height, width). The decoded boxes are clipped to the range y=[0, height] and x=[0, width]. Typically, the clipping size is (1, 1) for normalized boxes and the image size for boxes in pixel coordinates.
  • score_converter (ScoreConverter): Conversion to apply to the input logits (sigmoid, softmax, or linear).
  • score_threshold (float): Score threshold for non-maximum suppression.
  • iou_threshold (float): Intersection over union threshold for non-maximum suppression.
  • max_detections (int): The number of detections to return.
  • remove_background (bool): If True, the first class is removed from the input scores (after the score_converter is applied).
Inputs:

A list or tuple of:

  • rel_codes (Tensor): Relative codes (encoded offsets) with a shape of (batch, n_boxes, 4) in centroid coordinates (y_center, x_center, h, w).
  • scores (Tensor): Scores or logits with a shape of (batch, n_boxes, n_labels).
Returns:

'CombinedNonMaxSuppression' named tuple:

  • nmsed_boxes: Selected boxes sorted by scores in descending order, with a shape of (batch, max_detections, 4), in corner coordinates (y_min, x_min, y_max, x_max).
  • nmsed_scores: Scores corresponding to the selected boxes, with a shape of (batch, max_detections).
  • nmsed_classes: Labels corresponding to the selected boxes, with a shape of (batch, max_detections). Each label corresponds to the class index of the selected score in the input scores.
  • valid_detections: The number of valid detections out of max_detections.
Raises:
  • ValueError: If provided with invalid arguments or input tensors with unexpected or non-matching shapes.
Example:
from sony_custom_layers.keras import SSDPostProcess, ScoreConverter

post_process = SSDPostProcess(anchors=anchors,
                              scale_factors=(10, 10, 5, 5),
                              clip_size=(320, 320),
                              score_converter=ScoreConverter.SIGMOID,
                              score_threshold=0.01,
                              iou_threshold=0.6,
                              max_detections=200,
                              remove_background=True)
res = post_process([rel_codes, logits])
boxes = res.nmsed_boxes
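
A minimal end-to-end sketch (the anchor grid and random logits below are made up for illustration; the layer only requires that n_boxes match between the two inputs):

import numpy as np
import tensorflow as tf
from sony_custom_layers.keras import SSDPostProcess, ScoreConverter

n_boxes, n_labels = 4, 3
# Hypothetical anchors in pixel corner coordinates (y_min, x_min, y_max, x_max) for a 320x320 image.
anchors = np.array([[0, 0, 160, 160],
                    [0, 160, 160, 320],
                    [160, 0, 320, 160],
                    [160, 160, 320, 320]], dtype=np.float32)
post_process = SSDPostProcess(anchors=anchors,
                              scale_factors=(10, 10, 5, 5),
                              clip_size=(320, 320),
                              score_converter=ScoreConverter.SIGMOID,
                              score_threshold=0.01,
                              iou_threshold=0.6,
                              max_detections=10)

rel_codes = tf.zeros((1, n_boxes, 4))              # all-zero codes decode back to the anchors
logits = tf.random.normal((1, n_boxes, n_labels))
res = post_process([rel_codes, logits])
print(res.nmsed_boxes.shape)         # (1, 10, 4)
print(res.valid_detections.numpy())  # detections that passed the score threshold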
def custom_layers_scope(*args: dict):
def custom_layers_scope(*args: dict):
    """
    Scope context manager that can be used to deserialize Keras models containing custom layers.

    If the model contains custom layers only from this package:
    ```
    from sony_custom_layers.keras import custom_layers_scope
    with custom_layers_scope():
        tf.keras.models.load_model(path)
    ```
    If the model contains additional custom layers from other sources, there are two ways:
    1. Pass a list of dictionaries {layer_name: layer_object} as *args.
        ```
        with custom_layers_scope({'Op1': Op1, 'Op2': Op2}, {'Op3': Op3}):
            tf.keras.models.load_model(path)
        ```
    2. Combine it with other scopes based on tf.keras.utils.custom_object_scope:
        ```
        with custom_layers_scope(), another_scope():
            tf.keras.models.load_model(path)

        # or:

        with custom_layers_scope():
            with another_scope():
                tf.keras.models.load_model(path)
        ```
    Args:
        *args: a list of dictionaries for other custom layers

    Returns:
        Scope context manager
    """
    return tf.keras.utils.custom_object_scope(*args + (_custom_objects,))

Scope context manager that can be used to deserialize Keras models containing custom layers

If the model contains custom layers only from this package:

from sony_custom_layers.keras import custom_layers_scope
with custom_layers_scope():
    tf.keras.models.load_model(path)

If the model contains additional custom layers from other sources, there are two ways:

  1. Pass a list of dictionaries {layer_name: layer_object} as *args.

     with custom_layers_scope({'Op1': Op1, 'Op2': Op2}, {'Op3': Op3}):
         tf.keras.models.load_model(path)

  2. Combine it with other scopes based on tf.keras.utils.custom_object_scope:

     with custom_layers_scope(), another_scope():
         tf.keras.models.load_model(path)

     # or:

     with custom_layers_scope():
         with another_scope():
             tf.keras.models.load_model(path)

Arguments:
  • *args: a list of dictionaries for other custom layers
Returns:

Scope context manager
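
A hedged sketch of the full save/load round trip (the file name, anchor, and model below are illustrative, not part of the package's API):

import numpy as np
import tensorflow as tf
from sony_custom_layers.keras import FasterRCNNBoxDecode, custom_layers_scope

# Build and save a small functional model whose output is one of this package's custom layers.
anchors = np.array([[0., 0., 1., 1.]], dtype=np.float32)    # a single hypothetical anchor
inp = tf.keras.Input(shape=(1, 4))
out = FasterRCNNBoxDecode(anchors, scale_factors=(10, 10, 5, 5), clip_window=(0, 0, 1, 1))(inp)
model = tf.keras.Model(inputs=inp, outputs=out)
model.save('box_decode_model.keras')    # illustrative path

# Deserialize under the scope so the custom layer classes resolve.
with custom_layers_scope():
    loaded = tf.keras.models.load_model('box_decode_model.keras')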