Source code for ocnn.models.image2shape

# --------------------------------------------------------
# Octree-based Sparse Convolutional Neural Networks
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
# Licensed under The MIT License [see LICENSE for details]
# Written by Peng-Shuai Wang
# --------------------------------------------------------

import torch
import torch.nn
from typing import Optional
from torchvision.models import resnet18

import ocnn
from ocnn.octree import Octree


class Image2Shape(torch.nn.Module):
    r''' Predicts an octree-based shape from an image.

    A ResNet-18 encoder maps each image to a shape code that holds
    :attr:`code_channel` features for every node of the full-depth octree
    level (``8 ** full_depth`` nodes per image). An octree-based decoder then
    predicts, level by level from :attr:`full_depth` to :attr:`depth`, the
    split label of every octree node, and finally an output signal with
    :attr:`channel_out` channels on the finest level.

    Args:
      channel_out (int): The channel of the output signal.
      depth (int): The depth of the octree.
      full_depth (int): The full depth of the octree.
      code_channel (int): The number of shape-code channels assigned to each
          octree node at :attr:`full_depth`.

    Raises:
      ValueError: Unless ``0 <= full_depth <= depth < 10``, i.e. the depths
          are covered by the built-in per-depth channel schedule.
    '''

    def __init__(self, channel_out: int, depth: int, full_depth: int = 2,
                 code_channel: int = 32):
        super().__init__()
        self.depth = depth
        self.full_depth = full_depth
        self.channel_out = channel_out
        self.resblk_num = 2
        # number of feature channels used at each octree depth (index = depth)
        self.channels = [512, 512, 256, 256, 128, 128, 64, 64, 32, 32]
        self.code_channel = code_channel

        # Validate early: with `depth < full_depth` the decoder loop would never
        # reach the finest level (so no signal would ever be produced), a depth
        # beyond the schedule would fail with an opaque IndexError, and a
        # negative `full_depth` would silently index the schedule from the end.
        if not 0 <= full_depth <= depth < len(self.channels):
            raise ValueError(
                f'Expected 0 <= full_depth <= depth < {len(self.channels)}, '
                f'got full_depth={full_depth}, depth={depth}.')

        # encoder: ResNet-18 whose final linear layer is replaced so that it
        # emits one `code_channel`-dim code per full-depth octree node.
        self.resnet18 = resnet18()
        channel = self.code_channel * 2 ** (3 * full_depth)
        self.resnet18.fc = torch.nn.Linear(512, channel, bias=True)

        # decoder
        self.channels[full_depth] = self.code_channel  # update `channels`
        self.upsample = torch.nn.ModuleList([
            ocnn.modules.OctreeDeconvBnRelu(
                self.channels[d - 1], self.channels[d], kernel_size=[2],
                stride=2, nempty=False)
            for d in range(full_depth + 1, depth + 1)])
        self.decoder_blks = torch.nn.ModuleList([
            ocnn.modules.OctreeResBlocks(
                self.channels[d], self.channels[d], self.resblk_num,
                nempty=False)
            for d in range(full_depth, depth + 1)])

        # header: a 2-way split classifier per depth, plus the signal predictor
        # on the finest depth
        self.predict = torch.nn.ModuleList([
            self._make_predict_module(self.channels[d], 2)
            for d in range(full_depth, depth + 1)])
        self.header = self._make_predict_module(
            self.channels[depth], channel_out)

    def _make_predict_module(self, channel_in, channel_out=2, num_hidden=64):
        r''' Builds a two-layer 1x1 prediction head
        (Conv1x1 + BN + ReLU, then a biased Conv1x1).
        '''
        return torch.nn.Sequential(
            ocnn.modules.Conv1x1BnRelu(channel_in, num_hidden),
            ocnn.modules.Conv1x1(num_hidden, channel_out, use_bias=True))

    def decoder(self, shape_code: torch.Tensor, octree: Octree,
                update_octree: bool = False):
        r''' The octree-based decoder network.

        Args:
          shape_code (torch.Tensor): Features of the full-depth octree nodes,
              laid out as produced by :meth:`forward`, i.e.
              ``(num_full_depth_nodes, code_channel)``.
          octree (Octree): The octree to decode on. When ``update_octree`` is
              True it is split/grown in place from the predicted labels.
          update_octree (bool): If True, refine ``octree`` according to the
              predicted split labels and store the predicted signal in
              ``octree.features[depth]``.

        Returns:
          dict: ``'logits'`` maps each depth in ``[full_depth, depth]`` to the
          split logits predicted at that depth; ``'signal'`` is the
          tanh-activated, depadded output signal at ``depth``;
          ``'octree_out'`` is ``octree`` itself.
        '''
        logits = {}
        deconv = shape_code
        depth, full_depth = self.depth, self.full_depth
        for i, d in enumerate(range(full_depth, depth + 1)):
            if d > full_depth:
                # upsample the features from depth d-1 to depth d
                deconv = self.upsample[i - 1](deconv, octree, d - 1)
            deconv = self.decoder_blks[i](deconv, octree, d)

            # predict the splitting label
            logit = self.predict[i](deconv)
            logits[d] = logit

            # update the octree according to predicted labels
            if update_octree:
                split = logit.argmax(1).int()
                octree.octree_split(split, d)
                if d < depth:
                    octree.octree_grow(d + 1)

            # predict the signal on the finest level (the last iteration)
            if d == depth:
                signal = self.header(deconv)
                signal = torch.tanh(signal)
                signal = ocnn.nn.octree_depad(signal, octree, depth)
                if update_octree:
                    octree.features[depth] = signal

        return {'logits': logits, 'signal': signal, 'octree_out': octree}

    def decode_code(self, shape_code: torch.Tensor):
        r''' Decodes the shape code to an output octree.

        Args:
          shape_code (torch.Tensor): The shape code for decoding.

        Returns:
          dict: The output of :meth:`decoder`, run on a freshly initialized
          full octree with ``update_octree=True``.
        '''
        octree_out = self.init_octree(shape_code)
        return self.decoder(shape_code, octree_out, update_octree=True)

    def init_octree(self, shape_code: torch.Tensor):
        r''' Initialize a full octree for decoding.

        Args:
          shape_code (torch.Tensor): The shape code for decoding, used to get
              the `batch_size` and `device` to initialize the output octree.
              Its first dimension must count ``8 ** full_depth`` rows per
              batch element.
        '''
        node_num = 2 ** (3 * self.full_depth)
        batch_size = shape_code.size(0) // node_num
        octree = ocnn.octree.init_octree(
            self.depth, self.full_depth, batch_size, shape_code.device)
        return octree

    def forward(self, image: torch.Tensor, octree: Optional[Octree] = None,
                update_octree: bool = False):
        r''' Encodes images into shape codes and decodes them into octrees.

        Args:
          image (torch.Tensor): The input image batch fed to the ResNet-18
              encoder.
          octree (Octree, optional): The octree to decode on. Required when
              ``update_octree`` is False; otherwise it is ignored and a fresh
              full octree is created by :meth:`init_octree`.
          update_octree (bool): If True, build the output octree from the
              predicted split labels.

        Returns:
          dict: The output of :meth:`decoder`.

        Raises:
          ValueError: If ``octree`` is None while ``update_octree`` is False.
        '''
        # Fail fast instead of running the encoder and then crashing deep
        # inside the decoder on a missing octree.
        if octree is None and not update_octree:
            raise ValueError(
                '`octree` must be provided when `update_octree` is False.')

        shape_code = self.resnet18(image)
        # one `code_channel`-dim row per full-depth octree node
        shape_code = shape_code.view(-1, self.code_channel)
        if update_octree:
            octree = self.init_octree(shape_code)
        return self.decoder(shape_code, octree, update_octree)