Source code for ocnn.nn.octree_pad

# --------------------------------------------------------
# Octree-based Sparse Convolutional Neural Networks
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
# Licensed under The MIT License [see LICENSE for details]
# Written by Peng-Shuai Wang
# --------------------------------------------------------

import torch

from ocnn.octree import Octree


def octree_pad(data: torch.Tensor, octree: Octree, depth: int, val: float = 0.0):
    r''' Pads :attr:`val` to make the number of elements of :attr:`data` equal to
    the octree node number.

    Args:
        data (torch.Tensor): The input tensor with its number of elements equal
            to the non-empty octree node number. Its first dimension is indexed
            by the non-empty node indices; any trailing dimensions (e.g.
            channels) are kept as-is.
        octree (Octree): The corresponding octree.
        depth (int): The depth of current octree.
        val (float): The padding value. (Default: :obj:`0.0`)

    Returns:
        torch.Tensor: A tensor whose first dimension is ``octree.nnum[depth]``
        and whose trailing dimensions match :attr:`data`. Rows at
        ``octree.nempty_index(depth)`` hold :attr:`data`; all other rows are
        filled with :attr:`val`. It has the same dtype and device as
        :attr:`data`.
    '''

    idx = octree.nempty_index(depth)
    # Preserve every trailing dimension of `data`, so (N,), (N, C) and
    # (N, C, ...) inputs are all supported, instead of assuming exactly (N, C).
    size = (octree.nnum[depth], *data.shape[1:])
    out = torch.full(size, val, dtype=data.dtype, device=data.device)
    # Scatter the non-empty rows into place; untouched rows keep `val`.
    out[idx] = data
    return out
def octree_depad(data: torch.Tensor, octree: Octree, depth: int):
    r''' Reverse operation of :func:`octree_pad`. Please refer to
    :func:`octree_pad` for the meaning of the arguments.

    Returns:
        torch.Tensor: The rows of :attr:`data` selected by
        ``octree.nempty_index(depth)``, i.e. the entries belonging to the
        non-empty octree nodes at :attr:`depth`.
    '''

    idx = octree.nempty_index(depth)
    return data[idx]