Source code for ocnn.models.ounet

# --------------------------------------------------------
# Octree-based Sparse Convolutional Neural Networks
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
# Licensed under The MIT License [see LICENSE for details]
# Written by Peng-Shuai Wang
# --------------------------------------------------------


import ocnn
import torch
import torch.nn

from ocnn.octree import Octree
from ocnn.models.autoencoder import AutoEncoder


class OUNet(AutoEncoder):
  r''' Octree encoder-decoder network with skip connections.

  The encoder runs from ``depth`` down to ``full_depth``. The decoder runs
  back up from ``full_depth`` to ``depth``, predicting split logits at every
  depth and a ``tanh``-activated signal at the finest depth.

  Args:
    channel_in (int): Number of input feature channels. :meth:`encoder`
        checks that ``octree.get_input_feature`` returns this many channels.
    channel_out (int): Output channel count, forwarded to
        ``AutoEncoder.__init__``.
    depth (int): The finest octree depth processed by the network.
    full_depth (int): The depth where the encoder stops and the decoder
        starts. The decoding octree built by :meth:`init_octree` is full up
        to this depth.
    feature (str): Feature specification passed to
        ``octree.get_input_feature``.

  Raises:
    ValueError: If ``depth < full_depth``.
  '''

  def __init__(self, channel_in: int, channel_out: int, depth: int,
               full_depth: int = 2, feature: str = 'ND'):
    if depth < full_depth:
      # With depth < full_depth the encoder loop is empty, so it never
      # produces `convs[full_depth]`. The decoder reads that entry first,
      # so fail early with a clear message.
      raise ValueError(
          f'depth ({depth}) must be >= full_depth ({full_depth})')
    super().__init__(channel_in, channel_out, depth, full_depth, feature)
    self.proj = None  # remove this module used in AutoEncoder

  def encoder(self, octree: Octree) -> dict:
    r''' The encoder network for extracting hierarchical features.

    An encoder block is applied at each depth from ``self.depth`` down to
    ``self.full_depth``. Above ``full_depth``, the result is then
    downsampled to become the input of the next coarser depth.

    Args:
      octree (Octree): The input octree.

    Returns:
      dict: Maps each depth in ``[full_depth, depth]`` to the encoder
      features at that depth.
    '''

    convs = {}
    depth, full_depth = self.depth, self.full_depth
    data = octree.get_input_feature(self.feature, nempty=False)
    assert data.size(1) == self.channel_in, (
        f'expected {self.channel_in} input channels, got {data.size(1)}')
    convs[depth] = self.conv1(data, octree, depth)
    for i, d in enumerate(range(depth, full_depth - 1, -1)):
      convs[d] = self.encoder_blks[i](convs[d], octree, d)
      if d > full_depth:
        # `downsample[i]` turns depth-d features into the depth d-1 input.
        convs[d - 1] = self.downsample[i](convs[d], octree, d)
    return convs

  def decoder(self, convs: dict, octree_in: Octree, octree_out: Octree,
              update_octree: bool = False) -> dict:
    r''' The decoder network for decoding the octree.

    The decoder runs from ``self.full_depth`` up to ``self.depth``. At every
    depth above ``full_depth``, the running features are upsampled and added
    to the encoder features of that depth (output-guided skip connection).
    A split logit is predicted at every depth. After the finest depth, a
    ``tanh``-activated signal is predicted.

    Args:
      convs (dict): Encoder features keyed by depth, as returned by
          :meth:`encoder`.
      octree_in (Octree): The octree the encoder features were computed on.
      octree_out (Octree): The octree used as the decoding structure. It is
          modified in place when ``update_octree`` is True.
      update_octree (bool): If True, ``octree_out`` is split (and, below
          ``depth``, grown) at every depth from the predicted logits. The
          predicted signal is also stored in ``octree_out.features[depth]``.

    Returns:
      dict: ``{'logits': {d: logit, ...}, 'signal': signal,
      'octree_out': octree_out}``.
    '''

    logits = {}
    depth, full_depth = self.depth, self.full_depth
    deconv = convs[full_depth]
    for i, d in enumerate(range(full_depth, depth + 1)):
      if d > full_depth:
        deconv = self.upsample[i - 1](deconv, octree_out, d - 1)
        # Output-guided skip connection: `octree_align` brings the encoder
        # features of `octree_in` into correspondence with the depth-d nodes
        # of `octree_out`. This is required for the addition below.
        skip = ocnn.nn.octree_align(convs[d], octree_in, octree_out, d)
        deconv = deconv + skip
      deconv = self.decoder_blks[i](deconv, octree_out, d)

      # Predict the splitting label at this depth.
      logit = self.predict[i](deconv)
      logits[d] = logit

      # Update the octree according to the predicted labels. The argmax
      # class index along dim 1 is used as the per-node split label.
      if update_octree:
        split = logit.argmax(1).int()
        octree_out.octree_split(split, d)
        if d < depth:
          octree_out.octree_grow(d + 1)

    # Predict the signal. The last loop iteration is d == depth, so `deconv`
    # holds the decoder-block output at the finest depth. Computing it here
    # guarantees `signal` is always bound.
    signal = self.header(deconv)
    signal = torch.tanh(signal)
    # NOTE(review): `octree_depad` presumably removes padded/empty-node rows
    # at `depth` -- confirm against ocnn.nn.
    signal = ocnn.nn.octree_depad(signal, octree_out, depth)
    if update_octree:
      octree_out.features[depth] = signal

    return {'logits': logits, 'signal': signal, 'octree_out': octree_out}

  def init_octree(self, octree_in: Octree) -> Octree:
    r''' Initialize a full octree for decoding.

    The new octree uses the network's ``depth`` and ``full_depth``, and the
    batch size and device of ``octree_in``. It is grown full at every depth
    from 0 to ``full_depth``.

    Args:
      octree_in (Octree): Octree whose ``batch_size`` and ``device`` are
          reused.

    Returns:
      Octree: The initialized octree.
    '''

    device = octree_in.device
    batch_size = octree_in.batch_size
    octree = Octree(self.depth, self.full_depth, batch_size, device)
    for d in range(self.full_depth + 1):
      octree.octree_grow_full(depth=d)
    return octree

  def forward(self, octree_in: Octree, octree_out=None,
              update_octree: bool = False) -> dict:
    r''' Encode ``octree_in`` and decode it on ``octree_out``.

    Args:
      octree_in (Octree): The input octree.
      octree_out (Octree, optional): The octree to decode on. If None, a new
          octree is created by :meth:`init_octree`, and ``update_octree`` is
          forced to True. The decoder then builds the structure from its
          predicted split labels.
      update_octree (bool): Whether the decoder updates ``octree_out`` from
          its predictions (see :meth:`decoder`).

    Returns:
      dict: The output of :meth:`decoder`.
    '''

    if octree_out is None:
      update_octree = True
      octree_out = self.init_octree(octree_in)
    convs = self.encoder(octree_in)
    out = self.decoder(convs, octree_in, octree_out, update_octree)
    return out