# Source code for ocnn.nn.octree_norm

# --------------------------------------------------------
# Octree-based Sparse Convolutional Neural Networks
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
# Licensed under The MIT License [see LICENSE for details]
# Written by Peng-Shuai Wang
# --------------------------------------------------------

import torch
import torch.nn
from typing import Optional

from ocnn.octree import Octree
from ocnn.utils import scatter_add


# Alias of torch.nn.BatchNorm1d exposed under an octree-prefixed name; no
# octree-specific behavior is added here.
OctreeBatchNorm = torch.nn.BatchNorm1d


class OctreeGroupNorm(torch.nn.Module):
    r''' A group normalization layer for the octree.

    Statistics are computed separately for every batch sample and every
    channel group. Rows of the input are assigned to batch samples through
    ``octree.batch_id``; channels are split into :attr:`group` blocks of
    consecutive channels.

    Args:
      in_channels (int): Number of input channels.
      group (int): Requested number of channel groups. It is reduced to
          ``in_channels // min_group_channels`` (but never below 1) when the
          request would leave fewer than :attr:`min_group_channels` channels
          per group.
      nempty (bool): Forwarded as the second argument of ``octree.batch_id``.
      min_group_channels (int): Minimum number of channels per group.
    '''

    def __init__(self, in_channels: int, group: int, nempty: bool = False,
                 min_group_channels: int = 4):
        super().__init__()
        self.eps = 1e-5
        self.nempty = nempty
        self.group = group
        self.in_channels = in_channels
        self.min_group_channels = min_group_channels
        if self.min_group_channels * self.group > in_channels:
            # Too few channels for the requested grouping: shrink the group
            # count. Clamp to 1 so that in_channels < min_group_channels does
            # not produce zero groups and a division by zero below.
            self.group = max(in_channels // self.min_group_channels, 1)
        assert in_channels % self.group == 0, (
            'in_channels ({}) must be divisible by the number of groups ({})'
            .format(in_channels, self.group))
        self.channels_per_group = in_channels // self.group

        # Per-channel affine parameters, shaped (1, C) so that they broadcast
        # over the rows of the input; filled in by reset_parameters().
        self.weights = torch.nn.Parameter(torch.empty(1, in_channels))
        self.bias = torch.nn.Parameter(torch.empty(1, in_channels))
        self.reset_parameters()

    def reset_parameters(self):
        r''' Resets the affine parameters to the identity transform
        (weights to one, bias to zero).
        '''
        torch.nn.init.ones_(self.weights)
        torch.nn.init.zeros_(self.bias)

    def forward(self, data: torch.Tensor, octree: 'Octree', depth: int):
        r''' Normalizes :attr:`data` per batch sample and channel group.

        Args:
          data (torch.Tensor): Input features, indexed below as
              (rows, in_channels).
          octree (Octree): Supplies ``batch_size`` and
              ``batch_id(depth, nempty)``.
          depth (int): Octree depth passed to ``octree.batch_id``.

        Returns:
          torch.Tensor: The normalized features after the affine transform.
        '''
        batch_size = octree.batch_size
        batch_id = octree.batch_id(depth, self.nempty)

        # NOTE(review): scatter_add is presumably a row-wise sum keyed by
        # batch_id with `batch_size` output rows -- confirm in ocnn.utils.
        ones = data.new_ones([data.shape[0], 1])
        count = scatter_add(ones, batch_id, dim=0, dim_size=batch_size)
        count = count * self.channels_per_group  # element number in each group
        inv_count = 1.0 / (count + self.eps)  # there might be 0 element sometimes

        mean = scatter_add(data, batch_id, dim=0, dim_size=batch_size) * inv_count
        mean = self._adjust_for_group(mean)
        out = data - mean.index_select(0, batch_id)

        var = scatter_add(out**2, batch_id, dim=0, dim_size=batch_size) * inv_count
        var = self._adjust_for_group(var)
        inv_std = 1.0 / (var + self.eps).sqrt()
        out = out * inv_std.index_select(0, batch_id)

        out = out * self.weights + self.bias
        return out

    def _adjust_for_group(self, tensor: torch.Tensor):
        r''' Adjust the tensor for the group.

        Sums the values of the channels that belong to the same group and
        writes the sum back to every channel of that group, keeping the
        (-1, in_channels) layout. With one channel per group the tensor is
        returned unchanged.
        '''
        if self.channels_per_group > 1:
            tensor = (tensor.reshape(-1, self.group, self.channels_per_group)
                            .sum(-1, keepdim=True)
                            .repeat(1, 1, self.channels_per_group)
                            .reshape(-1, self.in_channels))
        return tensor

    def extra_repr(self) -> str:
        r''' Returns the layer configuration shown in the module repr. '''
        return ('in_channels={}, group={}, nempty={}, min_group_channels={}').format(
            self.in_channels, self.group, self.nempty, self.min_group_channels)
class OctreeInstanceNorm(OctreeGroupNorm):
    r''' An instance normalization layer for the octree.

    Realized as a group normalization in which the number of groups equals
    the number of channels, i.e. every channel forms its own group.
    '''

    def __init__(self, in_channels: int, nempty: bool = False):
        # NOTE: group=in_channels, and min_group_channels=1 so that the parent
        # does not reduce the group count.
        group_norm_kwargs = dict(
            in_channels=in_channels,
            group=in_channels,
            nempty=nempty,
            min_group_channels=1,
        )
        super().__init__(**group_norm_kwargs)

    def extra_repr(self) -> str:
        r''' Returns the layer configuration shown in the module repr. '''
        template = 'in_channels={}, nempty={}'
        return template.format(self.in_channels, self.nempty)
class OctreeNorm(torch.nn.Module):
    r''' A normalization layer for the octree. It encapsulates octree-based
    batch, group and instance normalization.

    Args:
      in_channels (int): Number of input channels.
      norm_type (str): One of ``'batch_norm'``, ``'group_norm'`` or
          ``'instance_norm'``.
      group (int): Number of groups, used by ``'group_norm'`` only.
      min_group_channels (int): Minimum channels per group, used by
          ``'group_norm'`` only.
      nempty (bool): Forwarded to the group/instance normalization layers;
          ignored by ``'batch_norm'``.

    Raises:
      ValueError: If :attr:`norm_type` is not one of the supported names.
    '''

    def __init__(self, in_channels: int, norm_type: str = 'batch_norm',
                 group: int = 32, min_group_channels: int = 4,
                 nempty: bool = False):
        super().__init__()
        self.in_channels = in_channels
        self.norm_type = norm_type
        self.group = group
        self.min_group_channels = min_group_channels
        self.nempty = nempty

        if self.norm_type == 'batch_norm':
            self.norm = torch.nn.BatchNorm1d(in_channels)
        elif self.norm_type == 'group_norm':
            # Keyword arguments are required here: the third positional
            # parameter of OctreeGroupNorm is `nempty`, not
            # `min_group_channels`.
            self.norm = OctreeGroupNorm(
                in_channels, group, nempty=nempty,
                min_group_channels=min_group_channels)
        elif self.norm_type == 'instance_norm':
            self.norm = OctreeInstanceNorm(in_channels, nempty=nempty)
        else:
            raise ValueError(
                'Unsupported norm_type: {!r}'.format(self.norm_type))

    def forward(self, x: torch.Tensor, octree: Optional['Octree'] = None,
                depth: Optional[int] = None):
        r''' Applies the configured normalization to :attr:`x`.

        Args:
          x (torch.Tensor): Input features.
          octree (Octree, optional): Required by ``'group_norm'`` and
              ``'instance_norm'``; unused by ``'batch_norm'``.
          depth (int, optional): Octree depth; required together with
              :attr:`octree`.

        Raises:
          ValueError: If :attr:`octree` or :attr:`depth` is missing for an
              octree-dependent normalization, or if :attr:`norm_type` is
              unsupported.
        '''
        if self.norm_type == 'batch_norm':
            output = self.norm(x)
        elif self.norm_type in ('group_norm', 'instance_norm'):
            if octree is None or depth is None:
                raise ValueError(
                    '{} requires both `octree` and `depth`'.format(
                        self.norm_type))
            output = self.norm(x, octree, depth)
        else:
            raise ValueError(
                'Unsupported norm_type: {!r}'.format(self.norm_type))
        return output