# --------------------------------------------------------
# Octree-based Sparse Convolutional Neural Networks
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
# Licensed under The MIT License [see LICENSE for details]
# Written by Peng-Shuai Wang
# --------------------------------------------------------
import torch
import torch.nn
from typing import Dict
import ocnn
from ocnn.octree import Octree
class UNet(torch.nn.Module):
    r''' Octree-based UNet for segmentation.

    The network is a stem convolution, a stack of strided-convolution +
    residual-block encoder stages, a mirrored stack of strided-deconvolution +
    residual-block decoder stages with skip connections, an octree
    interpolation onto query points, and a two-layer 1x1 convolution header.

    Args:
      in_channels (int): Number of channels of the input feature :attr:`data`.
      out_channels (int): Number of channels produced by the header.
      interp (str): Interpolation mode, forwarded unchanged to
          :class:`ocnn.nn.OctreeInterp`.
      nempty (bool): Forwarded unchanged to every octree layer built here.
          NOTE(review): presumably restricts computation to non-empty octree
          nodes -- confirm against the ocnn layer documentation.
      **kwargs: Accepted for interface compatibility and ignored.
    '''

    def __init__(self, in_channels: int, out_channels: int, interp: str = 'linear',
                 nempty: bool = False, **kwargs):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.nempty = nempty

        # The channel/block tables must exist before any layer is built;
        # subclasses customise the architecture by overriding this hook.
        self.config_network()
        self.encoder_stages = len(self.encoder_blocks)
        self.decoder_stages = len(self.decoder_blocks)

        # encoder: stem convolution, then per stage a stride-2 convolution
        # followed by a group of residual blocks.
        self.conv1 = ocnn.modules.OctreeConvBnRelu(
            in_channels, self.encoder_channel[0], nempty=nempty)
        self.downsample = torch.nn.ModuleList([
            ocnn.modules.OctreeConvBnRelu(
                self.encoder_channel[i], self.encoder_channel[i + 1],
                kernel_size=[2], stride=2, nempty=nempty)
            for i in range(self.encoder_stages)])
        self.encoder = torch.nn.ModuleList([
            ocnn.modules.OctreeResBlocks(
                self.encoder_channel[i + 1], self.encoder_channel[i + 1],
                self.encoder_blocks[i], self.bottleneck, nempty, self.resblk)
            for i in range(self.encoder_stages)])

        # decoder: per stage a stride-2 deconvolution, then residual blocks
        # applied to the concatenation of the upsampled feature and the skip
        # feature. The skip channels are read from the end of
        # `encoder_channel` (index -2, -3, ...), i.e. in reverse encoder order.
        channel = [self.decoder_channel[i + 1] + self.encoder_channel[-i - 2]
                   for i in range(self.decoder_stages)]
        self.upsample = torch.nn.ModuleList([
            ocnn.modules.OctreeDeconvBnRelu(
                self.decoder_channel[i], self.decoder_channel[i + 1],
                kernel_size=[2], stride=2, nempty=nempty)
            for i in range(self.decoder_stages)])
        self.decoder = torch.nn.ModuleList([
            ocnn.modules.OctreeResBlocks(
                channel[i], self.decoder_channel[i + 1],
                self.decoder_blocks[i], self.bottleneck, nempty, self.resblk)
            for i in range(self.decoder_stages)])

        # header: the input width is the last entry of `decoder_channel`,
        # which equals the final decoder stage's output width as long as
        # `decoder_channel` has exactly `decoder_stages + 1` entries.
        self.octree_interp = ocnn.nn.OctreeInterp(interp, nempty)
        self.header = torch.nn.Sequential(
            ocnn.modules.Conv1x1BnRelu(self.decoder_channel[-1], self.head_channel),
            ocnn.modules.Conv1x1(self.head_channel, self.out_channels, use_bias=True))

    def config_network(self):
        r''' Configure the network channels and Resblock numbers.

        Called by :meth:`__init__` before any layer is created. It sets:
        :attr:`encoder_channel` / :attr:`decoder_channel` (channel widths, one
        more entry than the number of stages), :attr:`encoder_blocks` /
        :attr:`decoder_blocks` (residual blocks per stage; their lengths
        define the number of stages), :attr:`head_channel` (hidden width of
        the header), :attr:`bottleneck` and :attr:`resblk` (forwarded to
        :class:`ocnn.modules.OctreeResBlocks`).
        '''

        self.encoder_channel = [32, 32, 64, 128, 256]
        self.decoder_channel = [256, 256, 128, 96, 96]
        self.encoder_blocks = [2, 3, 4, 6]
        self.decoder_blocks = [2, 2, 2, 2]
        self.head_channel = 64
        self.bottleneck = 1
        self.resblk = ocnn.modules.OctreeResBlock2

    def unet_encoder(self, data: torch.Tensor, octree: Octree,
                     depth: int) -> Dict[int, torch.Tensor]:
        r''' The encoder of the U-Net.

        Args:
          data (torch.Tensor): The input feature, fed to the stem convolution
              at :attr:`depth`.
          octree (Octree): The octree passed to every layer.
          depth (int): The octree depth of :attr:`data`.

        Returns:
          A dict that maps an octree depth to the feature computed at that
          depth, with keys from :attr:`depth` down to
          :obj:`depth - self.encoder_stages`.
        '''

        convd = dict()
        convd[depth] = self.conv1(data, octree, depth)
        for i in range(self.encoder_stages):
            d = depth - i
            # The strided convolution consumes the depth-`d` feature; its
            # output is processed and stored one level coarser, at `d - 1`.
            conv = self.downsample[i](convd[d], octree, d)
            convd[d - 1] = self.encoder[i](conv, octree, d - 1)
        return convd

    def unet_decoder(self, convd: Dict[int, torch.Tensor], octree: Octree,
                     depth: int) -> torch.Tensor:
        r''' The decoder of the U-Net.

        Args:
          convd (Dict[int, torch.Tensor]): The depth-indexed features returned
              by :meth:`unet_encoder`.
          octree (Octree): The octree passed to every layer.
          depth (int): The depth at which decoding starts; :obj:`convd[depth]`
              is the initial feature.

        Returns:
          The feature at depth :obj:`depth + self.decoder_stages`.
        '''

        deconv = convd[depth]
        for i in range(self.decoder_stages):
            d = depth + i
            deconv = self.upsample[i](deconv, octree, d)
            deconv = torch.cat([convd[d + 1], deconv], dim=1)  # skip connections
            deconv = self.decoder[i](deconv, octree, d + 1)
        return deconv

    def forward(self, data: torch.Tensor, octree: Octree, depth: int,
                query_pts: torch.Tensor):
        r''' Run the encoder, the decoder, the interpolation and the header.

        Args:
          data (torch.Tensor): The input feature at octree depth :attr:`depth`.
          octree (Octree): The octree passed to every layer.
          depth (int): The octree depth of :attr:`data`.
          query_pts (torch.Tensor): The query points handed to
              :class:`ocnn.nn.OctreeInterp`.
              NOTE(review): the expected layout is defined by
              :class:`ocnn.nn.OctreeInterp` -- confirm there.

        Returns:
          The output of the header applied to the interpolated features.
        '''

        convd = self.unet_encoder(data, octree, depth)
        # Decoding starts at the coarsest depth the encoder reached.
        deconv = self.unet_decoder(convd, octree, depth - self.encoder_stages)

        # Each decoder stage raises the depth by one, so this is the depth of
        # `deconv`; it equals `depth` when both stage counts match.
        interp_depth = depth - self.encoder_stages + self.decoder_stages
        feature = self.octree_interp(deconv, octree, interp_depth, query_pts)
        logits = self.header(feature)
        return logits