# --------------------------------------------------------
# Octree-based Sparse Convolutional Neural Networks
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
# Licensed under The MIT License [see LICENSE for details]
# Written by Peng-Shuai Wang
# --------------------------------------------------------
import torch
from typing import List
import ocnn
from ocnn.octree import Octree
class Branches(torch.nn.Module):
def __init__(self, channels: List[int], resblk_num: int, nempty: bool = False):
super().__init__()
self.channels = channels
self.resblk_num = resblk_num
bottlenecks = [4 if c < 256 else 8 for c in channels] # to save parameters
self.resblocks = torch.nn.ModuleList([
ocnn.modules.OctreeResBlocks(ch, ch, resblk_num, bnk, nempty=nempty)
for ch, bnk in zip(channels, bottlenecks)])
def forward(self, datas: List[torch.Tensor], octree: Octree, depth: int):
num = len(self.channels)
torch._assert(len(datas) == num, 'Error')
out = [None] * num
for i in range(num):
depth_i = depth - i
out[i] = self.resblocks[i](datas[i], octree, depth_i)
return out
class TransFunc(torch.nn.Module):
def __init__(self, in_channels: int, out_channels: int, nempty: bool = False):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.nempty = nempty
self.maxpool = ocnn.nn.OctreeMaxPool(nempty=nempty)
self.upsample = ocnn.nn.OctreeUpsample(method='nearest', nempty=nempty)
if in_channels != out_channels:
self.conv1x1 = ocnn.modules.Conv1x1BnRelu(in_channels, out_channels)
def forward(self, data: torch.Tensor, octree: Octree,
in_depth: int, out_depth: int):
out = data
if in_depth > out_depth:
for d in range(in_depth, out_depth, -1):
out = self.maxpool(out, octree, d)
if self.in_channels != self.out_channels:
out = self.conv1x1(out)
if in_depth < out_depth:
if self.in_channels != self.out_channels:
out = self.conv1x1(out)
for d in range(in_depth, out_depth, 1):
out = self.upsample(out, octree, d)
return out
class Transitions(torch.nn.Module):
def __init__(self, channels: List[int], nempty: bool = False):
super().__init__()
self.channels = channels
self.nempty = nempty
num = len(self.channels)
self.trans_func = torch.nn.ModuleList()
for i in range(num - 1):
for j in range(num):
self.trans_func.append(TransFunc(channels[i], channels[j], nempty))
def forward(self, data: List[torch.Tensor], octree: Octree, depth: int):
num = len(self.channels)
features = [[None] * (num - 1) for _ in range(num)]
for i in range(num - 1):
for j in range(num):
k = i * num + j
in_depth = depth - i
out_depth = depth - j
features[j][i] = self.trans_func[k](
data[i], octree, in_depth, out_depth)
out = [None] * num
for j in range(num):
# In the original tensorflow implmentation, a relu is added after the sum.
out[j] = torch.stack(features[j], dim=0).sum(dim=0)
return out
class FrontLayer(torch.nn.Module):
def __init__(self, channels: List[int], nempty: bool = False):
super().__init__()
self.channels = channels
self.num = len(channels) - 1
self.nempty = nempty
self.conv = torch.nn.ModuleList([
ocnn.modules.OctreeConvBnRelu(channels[i], channels[i+1], nempty=nempty)
for i in range(self.num)])
self.maxpool = torch.nn.ModuleList([
ocnn.nn.OctreeMaxPool(nempty) for i in range(self.num - 1)])
def forward(self, data: torch.Tensor, octree: Octree, depth: int):
out = data
for i in range(self.num - 1):
depth_i = depth - i
out = self.conv[i](out, octree, depth_i)
out = self.maxpool[i](out, octree, depth_i)
out = self.conv[-1](out, octree, depth - self.num + 1)
return out
class ClsHeader(torch.nn.Module):
def __init__(self, channels: List[int], out_channels: int, nempty: bool = False):
super().__init__()
self.channels = channels
self.out_channels = out_channels
self.nempty = nempty
in_channels = int(torch.Tensor(channels).sum())
self.conv1x1 = ocnn.modules.Conv1x1BnRelu(in_channels, 1024)
self.global_pool = ocnn.nn.OctreeGlobalPool(nempty)
self.header = torch.nn.Sequential(
torch.nn.Flatten(start_dim=1),
torch.nn.Linear(1024, out_channels, bias=True))
# self.header = torch.nn.Sequential(
# ocnn.modules.FcBnRelu(512, 256),
# torch.nn.Dropout(p=0.5),
# torch.nn.Linear(256, out_channels))
def forward(self, data: List[torch.Tensor], octree: Octree, depth: int):
full_depth = 2
num = len(data)
outs = [x for x in data] # avoid modifying the input data
for i in range(num):
depth_i = depth - i
for d in range(depth_i, full_depth, -1):
outs[i] = ocnn.nn.octree_max_pool(outs[i], octree, d, self.nempty)
out = torch.cat(outs, dim=1)
out = self.conv1x1(out)
out = self.global_pool(out, octree, full_depth)
logit = self.header(out)
return logit
[docs]class HRNet(torch.nn.Module):
r''' Octree-based HRNet for classification and segmentation. '''
def __init__(self, in_channels: int, out_channels: int, stages: int = 3,
interp: str = 'linear', nempty: bool = False):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.interp = interp
self.nempty = nempty
self.stages = stages
self.resblk_num = 3
self.channels = [128, 256, 512, 512]
self.front = FrontLayer([in_channels, 32, self.channels[0]], nempty)
self.branches = torch.nn.ModuleList([
Branches(self.channels[:i+1], self.resblk_num, nempty)
for i in range(stages)])
self.transitions = torch.nn.ModuleList([
Transitions(self.channels[:i+2], nempty)
for i in range(stages-1)])
self.cls_header = ClsHeader(self.channels[:stages], out_channels, nempty)
[docs] def forward(self, data: torch.Tensor, octree: Octree, depth: int):
r''''''
convs = [self.front(data, octree, depth)]
depth = depth - 1 # the data is downsampled in `front`
for i in range(self.stages):
convs = self.branches[i](convs, octree, depth)
if i < self.stages - 1:
convs = self.transitions[i](convs, octree, depth)
logits = self.cls_header(convs, octree, depth)
return logits