# --------------------------------------------------------
# Octree-based Sparse Convolutional Neural Networks
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
# Licensed under The MIT License [see LICENSE for details]
# Written by Peng-Shuai Wang
# --------------------------------------------------------
import torch
from ocnn.octree import Octree
def octree_pad(data: torch.Tensor, octree: Octree, depth: int, val: float = 0.0):
    r''' Pads :attr:`val` to make the number of elements of :attr:`data` equal to
    the octree node number.

    Args:
      data (torch.Tensor): The input tensor with its number of elements (size of
          the first dimension) equal to the non-empty octree node number. Any
          number of trailing dimensions is accepted, e.g. :obj:`(N, C)`.
      octree (Octree): The corresponding octree.
      depth (int): The depth of current octree.
      val (float): The padding value. (Default: :obj:`0.0`)

    Returns:
      torch.Tensor: A new tensor whose first dimension is
      :obj:`octree.nnum[depth]` and whose trailing dimensions, dtype and device
      match :attr:`data`. Rows selected by :obj:`octree.nempty_index(depth)`
      hold :attr:`data`; every other row is filled with :attr:`val`.
    '''
    idx = octree.nempty_index(depth)
    # Keep every trailing dimension of `data` instead of assuming a 2-D (N, C)
    # layout, so 1-D and higher-rank feature tensors are padded as well.
    size = (octree.nnum[depth], *data.shape[1:])
    out = torch.full(size, val, dtype=data.dtype, device=data.device)
    out[idx] = data
    return out
def octree_depad(data: torch.Tensor, octree: Octree, depth: int):
    r''' Reverse operation of :func:`octree_pad`.

    Please refer to :func:`octree_pad` for the meaning of the arguments.

    Args:
      data (torch.Tensor): The padded tensor, indexed along its first dimension.
      octree (Octree): The corresponding octree.
      depth (int): The depth of current octree.

    Returns:
      torch.Tensor: The entries of :attr:`data` selected by
      :obj:`octree.nempty_index(depth)`.
    '''
    idx = octree.nempty_index(depth)
    return data[idx]