Source code for ocnn.nn.octree2vox

# --------------------------------------------------------
# Octree-based Sparse Convolutional Neural Networks
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
# Licensed under The MIT License [see LICENSE for details]
# Written by Peng-Shuai Wang
# --------------------------------------------------------

import torch

from ocnn.octree import Octree


[docs]def octree2voxel(data: torch.Tensor, octree: Octree, depth: int, nempty: bool = False): r''' Converts the input feature to the full-voxel-based representation. Args: data (torch.Tensor): The input feature. octree (Octree): The corresponding octree. depth (int): The depth of current octree. nempty (bool): If True, :attr:`data` only contains the features of non-empty octree nodes. ''' x, y, z, b = octree.xyzb(depth, nempty) num = 1 << depth channel = data.shape[1] vox = data.new_zeros([octree.batch_size, num, num, num, channel]) vox[b, x, y, z] = data return vox
[docs]class Octree2Voxel(torch.nn.Module): r''' Converts the input feature to the full-voxel-based representation. Please refer to :func:`octree2voxel` for details. ''' def __init__(self, nempty: bool = False): super().__init__() self.nempty = nempty
[docs] def forward(self, data: torch.Tensor, octree: Octree, depth: int): r'''''' return octree2voxel(data, octree, depth, self.nempty)
def extra_repr(self) -> str: return 'nempty={}'.format(self.nempty)