Source code for ocnn.nn.octree_pool

# --------------------------------------------------------
# Octree-based Sparse Convolutional Neural Networks
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
# Licensed under The MIT License [see LICENSE for details]
# Written by Peng-Shuai Wang
# --------------------------------------------------------

import torch
import torch.nn
from typing import List

from ocnn.octree import Octree
from ocnn.utils import meshgrid, scatter_add, resize_with_last_val, list2str
from . import octree_pad, octree_depad


[docs]def octree_max_pool(data: torch.Tensor, octree: Octree, depth: int, nempty: bool = False, return_indices: bool = False): r''' Performs octree max pooling with kernel size 2 and stride 2. Args: data (torch.Tensor): The input tensor. octree (Octree): The corresponding octree. depth (int): The depth of current octree. After pooling, the corresponding depth decreased by 1. nempty (bool): If True, :attr:`data` contains only features of non-empty octree nodes. return_indices (bool): If True, returns the indices, which can be used in :func:`octree_max_unpool`. ''' if nempty: data = octree_pad(data, octree, depth, float('-inf')) data = data.view(-1, 8, data.shape[1]) out, indices = data.max(dim=1) if not nempty: out = octree_pad(out, octree, depth-1) return (out, indices) if return_indices else out
[docs]def octree_max_unpool(data: torch.Tensor, indices: torch.Tensor, octree: Octree, depth: int, nempty: bool = False): r''' Performs octree max unpooling. Args: data (torch.Tensor): The input tensor. indices (torch.Tensor): The indices returned by :func:`octree_max_pool`. The depth of :attr:`indices` is larger by 1 than :attr:`data`. octree (Octree): The corresponding octree. depth (int): The depth of current data. After unpooling, the corresponding depth increases by 1. ''' if not nempty: data = octree_depad(data, octree, depth) num, channel = data.shape out = torch.zeros(num, 8, channel, dtype=data.dtype, device=data.device) i = torch.arange(num, dtype=indices.dtype, device=indices.device) k = torch.arange(channel, dtype=indices.dtype, device=indices.device) i, k = meshgrid(i, k, indexing='ij') out[i, indices, k] = data out = out.view(-1, channel) if nempty: out = octree_depad(out, octree, depth+1) return out
[docs]def octree_avg_pool(data: torch.Tensor, octree: Octree, depth: int, kernel: str, stride: int = 2, nempty: bool = False): r''' Performs octree average pooling. Args: data (torch.Tensor): The input tensor. octree (Octree): The corresponding octree. depth (int): The depth of current octree. kernel (str): The kernel size, like '333', '222'. stride (int): The stride of the pooling. nempty (bool): If True, :attr:`data` contains only features of non-empty octree nodes. ''' neigh = octree.get_neigh(depth, kernel, stride, nempty) N1 = data.shape[0] N2 = neigh.shape[0] K = neigh.shape[1] mask = neigh >= 0 val = 1.0 / (torch.sum(mask, dim=1) + 1e-8) mask = mask.view(-1) val = val.unsqueeze(1).repeat(1, K).reshape(-1) val = val[mask] row = torch.arange(N2, device=neigh.device) row = row.unsqueeze(1).repeat(1, K).view(-1) col = neigh.view(-1) indices = torch.stack([row[mask], col[mask]], dim=0).long() mat = torch.sparse_coo_tensor(indices, val, [N2, N1], device=data.device) out = torch.sparse.mm(mat, data) return out
[docs]def octree_global_pool(data: torch.Tensor, octree: Octree, depth: int, nempty: bool = False): r''' Performs octree global average pooling. Args: data (torch.Tensor): The input tensor. octree (Octree): The corresponding octree. depth (int): The depth of current octree. nempty (bool): If True, :attr:`data` contains only features of non-empty octree nodes. ''' batch_size = octree.batch_size batch_id = octree.batch_id(depth, nempty) ones = data.new_ones(data.shape[0], 1) count = scatter_add(ones, batch_id, dim=0, dim_size=batch_size) count[count < 1] = 1 # there might be 0 element in some shapes out = scatter_add(data, batch_id, dim=0, dim_size=batch_size) out = out / count return out
class OctreePoolBase(torch.nn.Module): r''' The base class for octree-based pooling. ''' def __init__(self, kernel_size: List[int], stride: int, nempty: bool = False): super().__init__() self.kernel_size = resize_with_last_val(kernel_size) self.kernel = list2str(self.kernel_size) self.stride = stride self.nempty = nempty def extra_repr(self) -> str: return ('kernel_size={}, stride={}, nempty={}').format( self.kernel_size, self.stride, self.nempty) # noqa
[docs]class OctreeMaxPool(OctreePoolBase): r''' Performs octree max pooling. Please refer to :func:`octree_max_pool` for details. ''' def __init__(self, nempty: bool = False, return_indices: bool = False): super().__init__(kernel_size=[2], stride=2, nempty=nempty) self.return_indices = return_indices
[docs] def forward(self, data: torch.Tensor, octree: Octree, depth: int): r'''''' return octree_max_pool(data, octree, depth, self.nempty, self.return_indices)
[docs]class OctreeMaxUnpool(OctreePoolBase): r''' Performs octree max unpooling. Please refer to :func:`octree_max_unpool` for details. ''' def __init__(self, nempty: bool = False): super().__init__(kernel_size=[2], stride=2, nempty=nempty)
[docs] def forward(self, data: torch.Tensor, indices: torch.Tensor, octree: Octree, depth: int): r'''''' return octree_max_unpool(data, indices, octree, depth, self.nempty)
[docs]class OctreeGlobalPool(OctreePoolBase): r''' Performs octree global pooling. Please refer to :func:`octree_global_pool` for details. ''' def __init__(self, nempty: bool = False): super().__init__(kernel_size=[-1], stride=-1, nempty=nempty)
[docs] def forward(self, data: torch.Tensor, octree: Octree, depth: int): r'''''' return octree_global_pool(data, octree, depth, self.nempty)
[docs]class OctreeAvgPool(OctreePoolBase): r''' Performs octree average pooling. Please refer to :func:`octree_avg_pool` for details. '''
[docs] def forward(self, data: torch.Tensor, octree: Octree, depth: int): r'''''' return octree_avg_pool( data, octree, depth, self.kernel, self.stride, self.nempty)