# Source code for ocnn.nn.octree2col

# --------------------------------------------------------
# Octree-based Sparse Convolutional Neural Networks
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
# Licensed under The MIT License [see LICENSE for details]
# Written by Peng-Shuai Wang
# --------------------------------------------------------

import torch
import torch.nn

from ocnn.octree import Octree
from ocnn.utils import scatter_add


def octree2col(data: torch.Tensor, octree: Octree, depth: int,
               kernel_size: str = '333', stride: int = 1,
               nempty: bool = False) -> torch.Tensor:
    r''' Gathers the neighboring features for convolutions.

    Args:
      data (torch.Tensor): The input data.
      octree (Octree): The corresponding octree.
      depth (int): The depth of current octree.
      kernel_size (str): The kernel shape, choose from :obj:`333`, :obj:`311`,
          :obj:`131`, :obj:`113`, :obj:`222`, :obj:`331`, :obj:`133`, and
          :obj:`313`.
      stride (int): The stride of neighborhoods (:obj:`1` or :obj:`2`). If the
          stride is :obj:`2`, it always returns the neighborhood of the first
          siblings, and the number of elements of output tensor is
          :obj:`octree.nnum[depth] / 8`.
      nempty (bool): If True, only returns the neighborhoods of the non-empty
          octree nodes.

    Returns:
      torch.Tensor: A tensor of shape :obj:`(N, K, C)` with the dtype and device
      of :attr:`data`, where :obj:`(N, K)` is the shape of the neighbor index
      table returned by :obj:`octree.get_neigh` and :obj:`C` is
      :obj:`data.shape[1]`. Slots whose neighbor index is negative are left as
      zeros.
    '''

    # Neighbor index table; a negative entry marks a slot with no neighbor.
    neigh = octree.get_neigh(depth, kernel_size, stride, nempty)

    # Start from zeros so that slots without a neighbor contribute nothing.
    size = (neigh.shape[0], neigh.shape[1], data.shape[1])
    out = torch.zeros(size, dtype=data.dtype, device=data.device)

    # Copy the feature rows of the existing neighbors into their slots.
    valid = neigh >= 0
    out[valid] = data[neigh[valid]]  # (N, K, C)
    return out
def col2octree(data: torch.Tensor, octree: Octree, depth: int,
               kernel_size: str = '333', stride: int = 1,
               nempty: bool = False):
    r''' Scatters the convolution features to an octree.

    Please refer to :func:`octree2col` for the usage of function parameters.

    Returns:
      The result of :func:`scatter_add` along dimension 0, with
      :obj:`octree.nnum_nempty[depth]` output rows if :attr:`nempty` is True
      and :obj:`octree.nnum[depth]` rows otherwise.
    '''

    # Neighbor index table; a negative entry marks a slot with no neighbor,
    # so those slots are dropped before scattering.
    neigh = octree.get_neigh(depth, kernel_size, stride, nempty)
    valid = neigh >= 0

    # The output has one row per octree node at `depth` (non-empty nodes only
    # when `nempty` is set).
    dim_size = octree.nnum_nempty[depth] if nempty else octree.nnum[depth]

    # NOTE(review): presumably `scatter_add` sums the rows that share the same
    # target index -- confirm against ocnn.utils.scatter_add.
    out = scatter_add(data[valid], neigh[valid], dim=0, dim_size=dim_size)
    return out