sony_custom_layers.pytorch
# -----------------------------------------------------------------------------
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
from typing import Optional, TYPE_CHECKING

from sony_custom_layers.util.import_util import validate_installed_libraries
from sony_custom_layers import required_libraries

if TYPE_CHECKING:
    # only needed for the string annotations below; onnxruntime stays an optional dependency at runtime
    import onnxruntime as ort

__all__ = [
    'multiclass_nms', 'NMSResults', 'multiclass_nms_with_indices', 'NMSWithIndicesResults', 'FasterRCNNBoxDecode',
    'load_custom_ops'
]

# Validate the torch requirement before importing the torch-dependent submodules below (hence the E402 suppressions).
validate_installed_libraries(required_libraries['torch'])
from sony_custom_layers.pytorch.nms import (  # noqa: E402
    multiclass_nms, NMSResults, multiclass_nms_with_indices, NMSWithIndicesResults)
from sony_custom_layers.pytorch.box_decode import FasterRCNNBoxDecode  # noqa: E402


def load_custom_ops(ort_session_ops: Optional['ort.SessionOptions'] = None) -> 'ort.SessionOptions':
    """
    Registers the custom ops implementation for onnxruntime, and sets up the SessionOptions object for onnxruntime
    session.

    Args:
        ort_session_ops: SessionOptions object to register the custom ops library on. If None, creates a new object.
            If provided, the same object is updated in place and returned.

    Returns:
        SessionOptions object with registered custom ops.

    Example:
        ```
        import onnxruntime as ort
        from sony_custom_layers.pytorch import load_custom_ops

        so = load_custom_ops()
        session = ort.InferenceSession(model_path, sess_options=so)
        session.run(...)
        ```
        You can also pass your own SessionOptions object upon which to register the custom ops
        ```
        load_custom_ops(ort_session_ops=so)
        ```
    """
    validate_installed_libraries(required_libraries['torch_ort'])

    # trigger onnxruntime op registration (these imports are needed only for their side effect)
    from .nms import nms_ort  # noqa: F401
    from .box_decode import box_decode_ort  # noqa: F401

    from onnxruntime_extensions import get_library_path
    from onnxruntime import SessionOptions

    # Explicit None check instead of relying on the truthiness of the caller's object.
    if ort_session_ops is None:
        ort_session_ops = SessionOptions()
    ort_session_ops.register_custom_ops_library(get_library_path())
    return ort_session_ops
def multiclass_nms(boxes, scores, score_threshold: float, iou_threshold: float, max_detections: int) -> NMSResults:
    """
    Multi-class non-maximum suppression.

    Selected detections are ordered by score, highest first. Every output tensor holds a fixed number of
    detections given by 'max_detections'; when fewer detections survive, the outputs are zero-padded up to
    'max_detections'.

    Use `multiclass_nms_with_indices` when the input indices of the selected boxes are needed as well.

    Args:
        boxes (Tensor): Input boxes with shape [batch, n_boxes, 4], specified in corner coordinates
            (x_min, y_min, x_max, y_max). Agnostic to the x-y axes order.
        scores (Tensor): Input scores with shape [batch, n_boxes, n_classes].
        score_threshold (float): The score threshold. Candidates with scores below the threshold are discarded.
        iou_threshold (float): The Intersection Over Union (IOU) threshold for boxes overlap.
        max_detections (int): The number of detections to return.

    Returns:
        'NMSResults' named tuple:
        - boxes: The selected boxes with shape [batch, max_detections, 4].
        - scores: The corresponding scores in descending order with shape [batch, max_detections].
        - labels: The labels for each box with shape [batch, max_detections].
        - n_valid: The number of valid detections out of 'max_detections' with shape [batch, 1].

    Raises:
        ValueError: If provided with invalid arguments or input tensors with unexpected or non-matching shapes.

    Example:
        ```
        import torch
        from sony_custom_layers.pytorch import multiclass_nms

        # batch size=1, 1000 boxes, 50 classes
        boxes = torch.rand(1, 1000, 4)
        scores = torch.rand(1, 1000, 50)
        res = multiclass_nms(boxes,
                             scores,
                             score_threshold=0.1,
                             iou_threshold=0.6,
                             max_detections=300)
        # res.boxes, res.scores, res.labels, res.n_valid
        ```
    """
    # The registered custom op returns a plain tuple of tensors; wrap it in the named results container.
    raw_outputs = torch.ops.sony.multiclass_nms(boxes, scores, score_threshold, iou_threshold, max_detections)
    return NMSResults(*raw_outputs)
Multi-class non-maximum suppression. Detections are returned in descending order of their scores. The output tensors always contain a fixed number of detections, as defined by 'max_detections'. If fewer detections are selected, the output tensors are zero-padded up to 'max_detections'.
If you also require the input indices of the selected boxes, see `multiclass_nms_with_indices`.
Arguments:
- boxes (Tensor): Input boxes with shape [batch, n_boxes, 4], specified in corner coordinates (x_min, y_min, x_max, y_max). Agnostic to the x-y axes order.
- scores (Tensor): Input scores with shape [batch, n_boxes, n_classes].
- score_threshold (float): The score threshold. Candidates with scores below the threshold are discarded.
- iou_threshold (float): The Intersection Over Union (IOU) threshold for boxes overlap.
- max_detections (int): The number of detections to return.
Returns:
'NMSResults' named tuple:
- boxes: The selected boxes with shape [batch, max_detections, 4].
- scores: The corresponding scores in descending order with shape [batch, max_detections].
- labels: The labels for each box with shape [batch, max_detections].
- n_valid: The number of valid detections out of 'max_detections' with shape [batch, 1]
Raises:
- ValueError: If provided with invalid arguments or input tensors with unexpected or non-matching shapes.
Example:
import torch
from sony_custom_layers.pytorch import multiclass_nms

# batch size=1, 1000 boxes, 50 classes
boxes = torch.rand(1, 1000, 4)
scores = torch.rand(1, 1000, 50)
res = multiclass_nms(boxes, scores, score_threshold=0.1, iou_threshold=0.6, max_detections=300)
# res.boxes, res.scores, res.labels, res.n_valid
class NMSResults(NamedTuple):
    """
    Container for non-maximum suppression results.

    Behaves as a plain tuple of tensors (boxes, scores, labels, n_valid). The helper methods transform every
    tensor at once and return a new container of the same type.
    """
    boxes: Tensor
    scores: Tensor
    labels: Tensor
    n_valid: Tensor

    # NOTE: these helpers are intentionally duplicated in every results container: a NamedTuple can neither gain
    # fields through inheritance nor use multiple inheritance, and tuple behavior is required (so no dataclasses).
    def detach(self) -> 'NMSResults':
        """ Return a new container holding the detached version of every tensor """
        return self.apply(lambda tensor: tensor.detach())

    def cpu(self) -> 'NMSResults':
        """ Return a new container with every tensor moved to the cpu """
        return self.apply(lambda tensor: tensor.cpu())

    def apply(self, f: Callable[[Tensor], Tensor]) -> 'NMSResults':
        """ Return a new container built by calling `f` on each tensor, in field order """
        return type(self)(*map(f, self))
Container for non-maximum suppression results
Create new instance of NMSResults(boxes, scores, labels, n_valid)
40 def detach(self) -> 'NMSResults': 41 """ Detach all tensors and return a new object """ 42 return self.apply(lambda t: t.detach())
Detach all tensors and return a new object
44 def cpu(self) -> 'NMSResults': 45 """ Move all tensors to cpu and return a new object """ 46 return self.apply(lambda t: t.cpu())
Move all tensors to cpu and return a new object
def multiclass_nms_with_indices(boxes, scores, score_threshold: float, iou_threshold: float,
                                max_detections: int) -> NMSWithIndicesResults:
    """
    Multi-class non-maximum suppression with indices.

    Selected detections are ordered by score, highest first. Every output tensor holds a fixed number of
    detections given by 'max_detections'; when fewer detections survive, the outputs are zero-padded up to
    'max_detections'.

    This operator is identical to `multiclass_nms` except that it also outputs the input indices of the selected
    boxes.

    Args:
        boxes (Tensor): Input boxes with shape [batch, n_boxes, 4], specified in corner coordinates
            (x_min, y_min, x_max, y_max). Agnostic to the x-y axes order.
        scores (Tensor): Input scores with shape [batch, n_boxes, n_classes].
        score_threshold (float): The score threshold. Candidates with scores below the threshold are discarded.
        iou_threshold (float): The Intersection Over Union (IOU) threshold for boxes overlap.
        max_detections (int): The number of detections to return.

    Returns:
        'NMSWithIndicesResults' named tuple:
        - boxes: The selected boxes with shape [batch, max_detections, 4].
        - scores: The corresponding scores in descending order with shape [batch, max_detections].
        - labels: The labels for each box with shape [batch, max_detections].
        - indices: Indices of the input boxes that have been selected.
        - n_valid: The number of valid detections out of 'max_detections' with shape [batch, 1].

    Raises:
        ValueError: If provided with invalid arguments or input tensors with unexpected or non-matching shapes.

    Example:
        ```
        import torch
        from sony_custom_layers.pytorch import multiclass_nms_with_indices

        # batch size=1, 1000 boxes, 50 classes
        boxes = torch.rand(1, 1000, 4)
        scores = torch.rand(1, 1000, 50)
        res = multiclass_nms_with_indices(boxes,
                                          scores,
                                          score_threshold=0.1,
                                          iou_threshold=0.6,
                                          max_detections=300)
        # res.boxes, res.scores, res.labels, res.indices, res.n_valid
        ```
    """
    # The registered custom op returns a plain tuple of tensors; wrap it in the named results container.
    raw_outputs = torch.ops.sony.multiclass_nms_with_indices(boxes, scores, score_threshold, iou_threshold,
                                                            max_detections)
    return NMSWithIndicesResults(*raw_outputs)
Multi-class non-maximum suppression with indices. Detections are returned in descending order of their scores. The output tensors always contain a fixed number of detections, as defined by 'max_detections'. If fewer detections are selected, the output tensors are zero-padded up to 'max_detections'.
This operator is identical to `multiclass_nms` except that it also outputs the input indices of the selected boxes.
Arguments:
- boxes (Tensor): Input boxes with shape [batch, n_boxes, 4], specified in corner coordinates (x_min, y_min, x_max, y_max). Agnostic to the x-y axes order.
- scores (Tensor): Input scores with shape [batch, n_boxes, n_classes].
- score_threshold (float): The score threshold. Candidates with scores below the threshold are discarded.
- iou_threshold (float): The Intersection Over Union (IOU) threshold for boxes overlap.
- max_detections (int): The number of detections to return.
Returns:
'NMSWithIndicesResults' named tuple:
- boxes: The selected boxes with shape [batch, max_detections, 4].
- scores: The corresponding scores in descending order with shape [batch, max_detections].
- labels: The labels for each box with shape [batch, max_detections].
- indices: Indices of the input boxes that have been selected.
- n_valid: The number of valid detections out of 'max_detections' with shape [batch, 1]
Raises:
- ValueError: If provided with invalid arguments or input tensors with unexpected or non-matching shapes.
Example:
import torch
from sony_custom_layers.pytorch import multiclass_nms_with_indices

# batch size=1, 1000 boxes, 50 classes
boxes = torch.rand(1, 1000, 4)
scores = torch.rand(1, 1000, 50)
res = multiclass_nms_with_indices(boxes, scores, score_threshold=0.1, iou_threshold=0.6, max_detections=300)
# res.boxes, res.scores, res.labels, res.indices, res.n_valid
class NMSWithIndicesResults(NamedTuple):
    """
    Container for non-maximum suppression with indices results.

    Behaves as a plain tuple of tensors (boxes, scores, labels, indices, n_valid). The helper methods transform
    every tensor at once and return a new container of the same type.
    """
    boxes: Tensor
    scores: Tensor
    labels: Tensor
    indices: Tensor
    n_valid: Tensor

    # NOTE: these helpers are intentionally duplicated in every results container: a NamedTuple can neither gain
    # fields through inheritance nor use multiple inheritance, and tuple behavior is required (so no dataclasses).
    def detach(self) -> 'NMSWithIndicesResults':
        """ Return a new container holding the detached version of every tensor """
        return self.apply(lambda tensor: tensor.detach())

    def cpu(self) -> 'NMSWithIndicesResults':
        """ Return a new container with every tensor moved to the cpu """
        return self.apply(lambda tensor: tensor.cpu())

    def apply(self, f: Callable[[Tensor], Tensor]) -> 'NMSWithIndicesResults':
        """ Return a new container built by calling `f` on each tensor, in field order """
        return type(self)(*map(f, self))
Container for non-maximum suppression with indices results
Create new instance of NMSWithIndicesResults(boxes, scores, labels, indices, n_valid)
40 def detach(self) -> 'NMSWithIndicesResults': 41 """ Detach all tensors and return a new object """ 42 return self.apply(lambda t: t.detach())
Detach all tensors and return a new object
44 def cpu(self) -> 'NMSWithIndicesResults': 45 """ Move all tensors to cpu and return a new object """ 46 return self.apply(lambda t: t.cpu())
Move all tensors to cpu and return a new object
class FasterRCNNBoxDecode(nn.Module):
    """
    Box decoding as per Faster R-CNN <https://arxiv.org/abs/1506.01497>.

    Args:
        anchors: Anchors with a shape of (n_boxes, 4) in corner coordinates (y_min, x_min, y_max, x_max).
        scale_factors: Scaling factors in the format (y, x, height, width).
        clip_window: Clipping window in the format (y_min, x_min, y_max, x_max).

    Inputs:
        **rel_codes** (Tensor): Relative codes (encoded offsets) with a shape of (batch, n_boxes, 4) in centroid
        coordinates (y_center, x_center, h, w).

    Returns:
        Decoded boxes with a shape of (batch, n_boxes, 4) in corner coordinates (y_min, x_min, y_max, x_max).

    Raises:
        ValueError: If provided with invalid arguments or an input tensor with unexpected shape.

    Example:
        ```
        from sony_custom_layers.pytorch import FasterRCNNBoxDecode

        box_decode = FasterRCNNBoxDecode(anchors,
                                         scale_factors=(10, 10, 5, 5),
                                         clip_window=(0, 0, 1, 1))
        decoded_boxes = box_decode(rel_codes)
        ```
    """

    def __init__(self, anchors: torch.Tensor, scale_factors: Sequence[Union[float, int]],
                 clip_window: Sequence[Union[float, int]]):
        super().__init__()

        # anchors must be a 2D table with exactly 4 coordinates per box
        if anchors.ndim != 2 or anchors.shape[-1] != 4:
            raise ValueError(f'Invalid anchors shape {anchors.shape}. Expected shape (n_boxes, 4).')
        self.register_buffer('anchors', anchors)
        device = anchors.device

        # scale factors and clip window become float32 buffers on the same device as the anchors
        if len(scale_factors) != 4:
            raise ValueError(f'Invalid scale factors {scale_factors}. Expected 4 values for (y, x, height, width).')
        self.register_buffer('scale_factors', torch.tensor(scale_factors, dtype=torch.float32, device=device))

        if len(clip_window) != 4:
            raise ValueError(f'Invalid clip window {clip_window}. Expected 4 values for (y_min, x_min, y_max, x_max).')
        self.register_buffer('clip_window', torch.tensor(clip_window, dtype=torch.float32, device=device))

    def forward(self, rel_codes: torch.Tensor) -> torch.Tensor:
        """ Decode `rel_codes` against the registered anchors, scale factors and clip window. """
        decoded_boxes = torch.ops.sony.faster_rcnn_box_decode(rel_codes, self.anchors, self.scale_factors,
                                                              self.clip_window)
        return decoded_boxes
Box decoding as per Faster R-CNN https://arxiv.org/abs/1506.01497.
Arguments:
- anchors: Anchors with a shape of (n_boxes, 4) in corner coordinates (y_min, x_min, y_max, x_max).
- scale_factors: Scaling factors in the format (y, x, height, width).
- clip_window: Clipping window in the format (y_min, x_min, y_max, x_max).
Inputs:
rel_codes (Tensor): Relative codes (encoded offsets) with a shape of (batch, n_boxes, 4) in centroid coordinates (y_center, x_center, h, w).
Returns:
Decoded boxes with a shape of (batch, n_boxes, 4) in corner coordinates (y_min, x_min, y_max, x_max).
Raises:
- ValueError: If provided with invalid arguments or an input tensor with unexpected shape
Example:
from sony_custom_layers.pytorch import FasterRCNNBoxDecode

box_decode = FasterRCNNBoxDecode(anchors, scale_factors=(10, 10, 5, 5), clip_window=(0, 0, 1, 1))
decoded_boxes = box_decode(rel_codes)
60 def __init__(self, anchors: torch.Tensor, scale_factors: Sequence[Union[float, int]], 61 clip_window: Sequence[Union[float, int]]): 62 super().__init__() 63 if not (len(anchors.shape) == 2 and anchors.shape[-1] == 4): 64 raise ValueError(f'Invalid anchors shape {anchors.shape}. Expected shape (n_boxes, 4).') 65 self.register_buffer('anchors', anchors) 66 67 if len(scale_factors) != 4: 68 raise ValueError(f'Invalid scale factors {scale_factors}. Expected 4 values for (y, x, height, width).') 69 self.register_buffer('scale_factors', torch.tensor(scale_factors, dtype=torch.float32, device=anchors.device)) 70 71 if len(clip_window) != 4: 72 raise ValueError(f'Invalid clip window {clip_window}. Expected 4 values for (y_min, x_min, y_max, x_max).') 73 self.register_buffer('clip_window', torch.tensor(clip_window, dtype=torch.float32, device=anchors.device))
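Initialize internal Module state. The constructor:
- validates that `anchors` has shape (n_boxes, 4) and that `scale_factors` and `clip_window` each have 4 values;
- registers all three as module buffers, with `scale_factors` and `clip_window` stored as float32 tensors on the anchors' device.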
75 def forward(self, rel_codes: torch.Tensor) -> torch.Tensor: 76 return torch.ops.sony.faster_rcnn_box_decode(rel_codes, self.anchors, self.scale_factors, self.clip_window)
Decode `rel_codes` using the registered `anchors`, `scale_factors` and `clip_window` (delegates to `torch.ops.sony.faster_rcnn_box_decode`).
Although the forward pass is defined in this function, call the module instance (e.g. `box_decode(rel_codes)`) rather than `forward` directly: the former runs the registered hooks, while the latter silently ignores them.
def load_custom_ops(ort_session_ops: Optional['ort.SessionOptions'] = None) -> 'ort.SessionOptions':
    """
    Registers the custom ops implementation for onnxruntime, and sets up the SessionOptions object for onnxruntime
    session.

    Args:
        ort_session_ops: SessionOptions object to register the custom ops library on. If None, creates a new object.
            If provided, the same object is updated in place and returned.

    Returns:
        SessionOptions object with registered custom ops.

    Example:
        ```
        import onnxruntime as ort
        from sony_custom_layers.pytorch import load_custom_ops

        so = load_custom_ops()
        session = ort.InferenceSession(model_path, sess_options=so)
        session.run(...)
        ```
        You can also pass your own SessionOptions object upon which to register the custom ops
        ```
        load_custom_ops(ort_session_ops=so)
        ```
    """
    validate_installed_libraries(required_libraries['torch_ort'])

    # trigger onnxruntime op registration (these imports are needed only for their side effect)
    from .nms import nms_ort  # noqa: F401
    from .box_decode import box_decode_ort  # noqa: F401

    from onnxruntime_extensions import get_library_path
    from onnxruntime import SessionOptions

    # Explicit None check instead of relying on the truthiness of the caller's object.
    if ort_session_ops is None:
        ort_session_ops = SessionOptions()
    ort_session_ops.register_custom_ops_library(get_library_path())
    return ort_session_ops
Registers the custom ops implementation for onnxruntime, and sets up the SessionOptions object for onnxruntime session.
Arguments:
- ort_session_ops: SessionOptions object to register the custom ops library on. If None, creates a new object.
Returns:
SessionOptions object with registered custom ops.
Example:
import onnxruntime as ort
from sony_custom_layers.pytorch import load_custom_ops

so = load_custom_ops()
session = ort.InferenceSession(model_path, sess_options=so)
session.run(...)
You can also pass your own SessionOptions object upon which to register the custom ops:
load_custom_ops(ort_session_ops=so)