# --------------------------------------------------------
# Octree-based Sparse Convolutional Neural Networks
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
# Licensed under The MIT License [see LICENSE for details]
# Written by Peng-Shuai Wang
# --------------------------------------------------------
import torch
import torch.sparse
from typing import Union, Optional
import ocnn
from ocnn.octree import Octree
[docs]def octree_nearest_pts(data: torch.Tensor, octree: Octree, depth: int,
pts: torch.Tensor, nempty: bool = False,
bound_check: bool = False):
''' The nearest-neighbor interpolatation with input points.
Args:
data (torch.Tensor): The input data.
octree (Octree): The octree to interpolate.
depth (int): The depth of the data.
pts (torch.Tensor): The coordinates of the points with shape :obj:`(N, 4)`,
i.e. :obj:`N x (x, y, z, batch)`.
nempty (bool): If true, the :attr:`data` only contains features of non-empty
octree nodes
bound_check (bool): If true, check whether the point is in :obj:`[0, 2^depth)`.
.. note::
The :attr:`pts` MUST be scaled into :obj:`[0, 2^depth)`.
'''
nnum = octree.nnum_nempty[depth] if nempty else octree.nnum[depth]
assert data.shape[0] == nnum, 'The shape of input data is wrong.'
idx = octree.search_xyzb(pts, depth, nempty)
valid = idx > -1 # valid indices
if bound_check:
bound = torch.logical_and(pts[:, :3] >= 0, pts[:, :3] < 2**depth).all(1)
valid = torch.logical_and(valid, bound)
size = (pts.shape[0], data.shape[1])
out = torch.zeros(size, device=data.device, dtype=data.dtype)
out[valid] = data.index_select(0, idx[valid])
return out
[docs]def octree_linear_pts(data: torch.Tensor, octree: Octree, depth: int,
pts: torch.Tensor, nempty: bool = False,
bound_check: bool = False):
''' Linear interpolatation with input points.
Refer to :func:`octree_nearest_pts` for the meaning of the arguments.
'''
nnum = octree.nnum_nempty[depth] if nempty else octree.nnum[depth]
assert data.shape[0] == nnum, 'The shape of input data is wrong.'
device = data.device
grid = torch.tensor(
[[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1],
[1, 0, 0], [1, 0, 1], [1, 1, 0], [1, 1, 1]], device=device)
# 1. Neighborhood searching
xyzf = pts[:, :3] - 0.5 # the value is defined on the center of each voxel
xyzi = xyzf.floor() # the integer part (N, 3)
frac = xyzf - xyzi # the fraction part (N, 3)
xyzn = (xyzi.unsqueeze(1) + grid).view(-1, 3)
batch = pts[:, 3].unsqueeze(1).repeat(1, 8).view(-1, 1)
idx = octree.search_xyzb(torch.cat([xyzn, batch], dim=1), depth, nempty)
valid = idx > -1 # valid indices
if bound_check:
bound = torch.logical_and(xyzn >= 0, xyzn < 2**depth).all(1)
valid = torch.logical_and(valid, bound)
idx = idx[valid]
# 2. Build the sparse matrix
npt = pts.shape[0]
ids = torch.arange(npt, device=idx.device)
ids = ids.unsqueeze(1).repeat(1, 8).view(-1)
ids = ids[valid]
indices = torch.stack([ids, idx], dim=0).long()
frac = (1.0 - grid) - frac.unsqueeze(dim=1) # (8, 3) - (N, 1, 3) -> (N, 8, 3)
weight = frac.prod(dim=2).abs().view(-1) # (8*N,)
weight = weight[valid]
h = data.shape[0]
mat = torch.sparse_coo_tensor(indices, weight, [npt, h], device=device)
# 3. Interpolatation
output = torch.sparse.mm(mat, data)
ones = torch.ones(h, 1, dtype=data.dtype, device=device)
norm = torch.sparse.mm(mat, ones)
output = torch.div(output, norm + 1e-12)
return output
[docs]class OctreeInterp(torch.nn.Module):
r''' Interpolates the points with an octree feature.
Refer to :func:`octree_nearest_pts` for a description of arguments.
'''
def __init__(self, method: str = 'linear', nempty: bool = False,
bound_check: bool = False, rescale_pts: bool = True):
super().__init__()
self.method = method
self.nempty = nempty
self.bound_check = bound_check
self.rescale_pts = rescale_pts
self.func = octree_linear_pts if method == 'linear' else octree_nearest_pts
[docs] def forward(self, data: torch.Tensor, octree: Octree, depth: int,
pts: torch.Tensor, bbmin: Union[torch.Tensor, float] = -1,
bbmax: Union[torch.Tensor, float] = 1):
r''''''
# rescale points from [bbmin, bbmax] to [0, 2^depth]
if self.rescale_pts:
box_size = bbmax - bbmin
if type(box_size) is torch.Tensor:
box_size = box_size.max().item()
assert box_size > 0, 'The bounding box size must be greater than 0.'
pts[:, :3] = (pts[:, :3] - bbmin) * (2**depth / box_size)
return self.func(data, octree, depth, pts, self.nempty, self.bound_check)
def extra_repr(self) -> str:
r''' Sets the extra representation of the module.
'''
return ('method={}, nempty={}, bound_check={}, rescale_pts={}').format(
self.method, self.nempty, self.bound_check, self.rescale_pts) # noqa
def octree_nearest_upsample(data: torch.Tensor, octree: Octree, depth: int,
nempty: bool = False):
r''' Upsamples the octree node features from :attr:`depth` to :attr:`(depth+1)`
with the nearest-neighbor interpolation.
Args:
data (torch.Tensor): The input data.
octree (Octree): The octree to interpolate.
depth (int): The depth of the data.
nempty (bool): If true, the :attr:`data` only contains features of non-empty
octree nodes.
'''
nnum = octree.nnum_nempty[depth] if nempty else octree.nnum[depth]
assert data.shape[0] == nnum, 'The shape of input data is wrong.'
out = data
if not nempty:
out = ocnn.nn.octree_depad(out, octree, depth)
out = out.unsqueeze(1).repeat(1, 8, 1).flatten(end_dim=1)
if nempty:
out = ocnn.nn.octree_depad(out, octree, depth+1) # !!! depth+1
return out
[docs]class OctreeUpsample(torch.nn.Module):
r''' Upsamples the octree node features from :attr:`depth` to
:attr:`(target_depth)`.
Refer to :class:`octree_nearest_pts` for details.
'''
def __init__(self, method: str = 'linear', nempty: bool = False):
super().__init__()
self.method = method
self.nempty = nempty
self.func = octree_linear_pts if method == 'linear' else octree_nearest_pts
[docs] def forward(self, data: torch.Tensor, octree: Octree, depth: int,
target_depth: Optional[int] = None):
r''''''
if target_depth is None:
target_depth = depth + 1
if target_depth == depth:
return data # return, do nothing
assert target_depth >= depth, 'target_depth must be larger than depth'
if target_depth == depth + 1 and self.method == 'nearest':
return octree_nearest_upsample(data, octree, depth, self.nempty)
xyzb = octree.xyzb(target_depth, self.nempty)
pts = torch.stack(xyzb, dim=1).float()
pts[:, :3] = (pts[:, :3] + 0.5) * (2**(depth - target_depth)) # !!! rescale
return self.func(data, octree, depth, pts, self.nempty)
def extra_repr(self) -> str:
r''' Sets the extra representation of the module.
'''
return ('method={}, nempty={}').format(self.method, self.nempty)