# Source code for ocnn.models.autoencoder

# --------------------------------------------------------
# Octree-based Sparse Convolutional Neural Networks
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
# Licensed under The MIT License [see LICENSE for details]
# Written by Peng-Shuai Wang
# --------------------------------------------------------

import torch
import torch.nn

import ocnn
from ocnn.octree import Octree


class AutoEncoder(torch.nn.Module):
    r''' Octree-based AutoEncoder for shape encoding and decoding.

    Args:
      channel_in (int): The channel of the input signal.
      channel_out (int): The channel of the output signal.
      depth (int): The depth of the octree.
      full_depth (int): The full depth of the octree.
      feature (str): The feature type of the input signal. For details of this
          argument, please refer to :class:`ocnn.modules.InputFeature`.

    Raises:
      ValueError: Unless ``0 <= full_depth <= depth < len(self.channels)``;
          every depth in ``[full_depth, depth]`` indexes the per-depth
          channel table ``self.channels``.
    '''

    def __init__(self, channel_in: int, channel_out: int, depth: int,
                 full_depth: int = 2, feature: str = 'ND'):
        super().__init__()
        self.channel_in = channel_in
        self.channel_out = channel_out
        self.depth = depth
        self.full_depth = full_depth
        self.feature = feature
        self.resblk_num = 2
        # Number of feature channels used at each octree depth (index = depth).
        self.channels = [512, 512, 256, 256, 128, 128, 32, 32, 16, 16]

        # Fail early with a clear message; out-of-range depths would otherwise
        # surface later as an IndexError here or an obscure error in
        # `encoder` / `decoder` (empty module lists, missing features).
        if not 0 <= full_depth <= depth < len(self.channels):
            raise ValueError(
                f'Expected 0 <= full_depth <= depth < {len(self.channels)}, '
                f'got full_depth={full_depth}, depth={depth}.')

        # dim-of-code = code_channel * 2**(3*full_depth)
        self.code_channel = self.channels[full_depth]

        # encoder: input conv at `depth`, residual blocks at every depth from
        # `depth` down to `full_depth`, and a stride-2 downsampling conv between
        # consecutive depths.
        self.conv1 = ocnn.modules.OctreeConvBnRelu(
            channel_in, self.channels[depth], nempty=False)
        self.encoder_blks = torch.nn.ModuleList([
            ocnn.modules.OctreeResBlocks(
                self.channels[d], self.channels[d], self.resblk_num,
                nempty=False)
            for d in range(depth, full_depth - 1, -1)])
        self.downsample = torch.nn.ModuleList([
            ocnn.modules.OctreeConvBnRelu(
                self.channels[d], self.channels[d - 1], kernel_size=[2],
                stride=2, nempty=False)
            for d in range(depth, full_depth, -1)])
        self.proj = torch.nn.Linear(
            self.channels[full_depth], self.code_channel, bias=True)

        # decoder
        # With the defaults this is a no-op (code_channel was read from the
        # same entry); it keeps the decoder sized from `code_channel` should
        # that ever differ from the encoder's full-depth width.
        self.channels[full_depth] = self.code_channel  # update `channels`
        self.upsample = torch.nn.ModuleList([
            ocnn.modules.OctreeDeconvBnRelu(
                self.channels[d - 1], self.channels[d], kernel_size=[2],
                stride=2, nempty=False)
            for d in range(full_depth + 1, depth + 1)])
        self.decoder_blks = torch.nn.ModuleList([
            ocnn.modules.OctreeResBlocks(
                self.channels[d], self.channels[d], self.resblk_num,
                nempty=False)
            for d in range(full_depth, depth + 1)])

        # header: one 2-way (split / no split) classifier per decoded depth,
        # plus the signal regressor at the finest depth.
        self.predict = torch.nn.ModuleList([
            self._make_predict_module(self.channels[d], 2)
            for d in range(full_depth, depth + 1)])
        self.header = self._make_predict_module(
            self.channels[depth], channel_out)

    def _make_predict_module(self, channel_in, channel_out=2, num_hidden=64):
        r''' Builds a two-layer 1x1 prediction head
        (Conv1x1BnRelu -> Conv1x1 with bias). '''
        return torch.nn.Sequential(
            ocnn.modules.Conv1x1BnRelu(channel_in, num_hidden),
            ocnn.modules.Conv1x1(num_hidden, channel_out, use_bias=True))

    def encoder(self, octree: Octree):
        r''' The encoder network of the AutoEncoder.

        Args:
          octree (Octree): The input octree. Its input features are obtained
              via ``octree.get_input_feature(self.feature, nempty=False)``
              and must have ``channel_in`` channels.

        Returns:
          torch.Tensor: The shape code: the features at ``full_depth``
          projected by ``self.proj`` and constrained to [-1, 1] by ``tanh``.
          :meth:`init_octree` assumes its leading dimension is
          ``batch_size * 2**(3 * full_depth)``.

        Raises:
          ValueError: If the input features do not have ``channel_in``
              channels.
        '''
        convs = dict()
        depth, full_depth = self.depth, self.full_depth
        data = octree.get_input_feature(self.feature, nempty=False)
        # An explicit check (not `assert`) so it is not stripped under -O.
        if data.size(1) != self.channel_in:
            raise ValueError(
                f'Expected input features with {self.channel_in} channels, '
                f'got {data.size(1)}.')
        convs[depth] = self.conv1(data, octree, depth)
        for i, d in enumerate(range(depth, full_depth - 1, -1)):
            convs[d] = self.encoder_blks[i](convs[d], octree, d)
            if d > full_depth:
                convs[d - 1] = self.downsample[i](convs[d], octree, d)

        # NOTE: here tanh is used to constrain the shape code in [-1, 1]
        shape_code = self.proj(convs[full_depth]).tanh()
        return shape_code

    def decoder(self, shape_code: torch.Tensor, octree: Octree,
                update_octree: bool = False):
        r''' The decoder network of the AutoEncoder.

        Args:
          shape_code (torch.Tensor): The shape code produced by
              :meth:`encoder`.
          octree (Octree): The octree the features are decoded on. When
              ``update_octree`` is True it is modified in place: split at
              every decoded depth according to the predicted labels, grown to
              the next depth, and its ``features[depth]`` set to the signal.
          update_octree (bool): Whether to update ``octree`` from the
              predictions.

        Returns:
          dict: ``'logits'`` maps each depth in ``[full_depth, depth]`` to its
          split logits, ``'signal'`` is the tanh-activated, depadded output
          signal at ``depth``, and ``'octree_out'`` is ``octree``.
        '''
        logits = dict()
        deconv = shape_code
        depth, full_depth = self.depth, self.full_depth
        for i, d in enumerate(range(full_depth, depth + 1)):
            if d > full_depth:
                # upsample[i - 1] maps depth d-1 to depth d
                deconv = self.upsample[i - 1](deconv, octree, d - 1)
            deconv = self.decoder_blks[i](deconv, octree, d)

            # predict the splitting label
            logit = self.predict[i](deconv)
            logits[d] = logit

            # update the octree according to predicted labels
            if update_octree:
                split = logit.argmax(1).int()
                octree.octree_split(split, d)
                if d < depth:
                    octree.octree_grow(d + 1)

        # Predict the signal at the finest depth. The loop always ends with
        # d == depth (guaranteed by the check in __init__), so `deconv` holds
        # depth-`depth` features and the split at `depth` has been applied,
        # exactly as when this ran inside the last iteration.
        signal = self.header(deconv)
        signal = torch.tanh(signal)
        signal = ocnn.nn.octree_depad(signal, octree, depth)
        if update_octree:
            octree.features[depth] = signal

        return {'logits': logits, 'signal': signal, 'octree_out': octree}

    def decode_code(self, shape_code: torch.Tensor):
        r''' Decodes the shape code to an output octree.

        Args:
          shape_code (torch.Tensor): The shape code for decoding.

        Returns:
          dict: The output of :meth:`decoder` run on a freshly initialized
          octree with ``update_octree=True``.
        '''
        octree_out = self.init_octree(shape_code)
        out = self.decoder(shape_code, octree_out, update_octree=True)
        return out

    def init_octree(self, shape_code: torch.Tensor):
        r''' Initialize a full octree for decoding.

        Args:
          shape_code (torch.Tensor): The shape code for decoding, used to get
              the `batch_size` and `device` to initialize the output octree.

        Returns:
          Octree: The octree created by ``ocnn.octree.init_octree`` with this
          model's ``depth`` and ``full_depth``.

        Raises:
          ValueError: If ``shape_code.size(0)`` is not a multiple of
              ``2**(3 * full_depth)``, the number of nodes per sample at the
              full depth.
        '''
        node_num = 2 ** (3 * self.full_depth)
        num_rows = shape_code.size(0)
        # Integer division would silently drop rows and yield a wrong batch
        # size (and an octree that no longer matches `shape_code`).
        if num_rows % node_num != 0:
            raise ValueError(
                f'shape_code.size(0) = {num_rows} is not a multiple of '
                f'2**(3 * full_depth) = {node_num}.')
        batch_size = num_rows // node_num
        octree = ocnn.octree.init_octree(
            self.depth, self.full_depth, batch_size, shape_code.device)
        return octree

    def forward(self, octree: Octree, update_octree: bool):
        r''' Encodes ``octree`` into a shape code and decodes it.

        Args:
          octree (Octree): The input octree.
          update_octree (bool): If True, decode into a new octree built by
              :meth:`init_octree` and update it from the predictions;
              otherwise decode on the input ``octree`` without modifying it.

        Returns:
          dict: The output of :meth:`decoder`.
        '''
        shape_code = self.encoder(octree)
        if update_octree:
            octree = self.init_octree(shape_code)
        out = self.decoder(shape_code, octree, update_octree)
        return out