# --------------------------------------------------------
# Octree-based Sparse Convolutional Neural Networks
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
# Licensed under The MIT License [see LICENSE for details]
# Written by Peng-Shuai Wang
# --------------------------------------------------------
import os
import torch
import torch.nn
# import warnings
from torch.autograd import Function
from packaging import version
from typing import List
from ocnn.octree import Octree
from ocnn.utils import scatter_add, xavier_uniform_, resize_with_last_val, list2str
from .octree2col import octree2col, col2octree
from .octree_pad import octree_pad, octree_depad
from .octree_conv_t import octree_conv_triton
# Opt-out switch for the Triton code path: it is True only when the environment
# variable OCNN_DISABLE_TRITON is exactly '1' (unset or any other value -> False).
DISABLE_TRITON = os.getenv('OCNN_DISABLE_TRITON', '0') == '1'
class OctreeConvBase:
    r''' Shared configuration and block-wise GEMM routines for the octree
    convolution and deconvolution.

    Subclasses must override :meth:`is_conv_layer`, and :meth:`setup` must be
    called with the current octree before any of the ``*_gemm`` methods.

    Args:
      in_channels (int): Number of input channels.
      out_channels (int): Number of output channels.
      kernel_size (List(int)): The kernel shape. It is passed through
          :func:`resize_with_last_val`; the result is indexed at 0, 1 and 2.
      stride (int): The stride of the convolution (:obj:`1` or :obj:`2`).
      nempty (bool): Forwarded to :meth:`Octree.get_neigh`; it also selects
          whether tensor heights are read from :obj:`octree.nnum_nempty` or
          :obj:`octree.nnum`.
      max_buffer (int): The maximum number of elements of the temporary
          buffer used by the ``*_gemm`` methods.
    '''

    def __init__(self, in_channels: int, out_channels: int,
                 kernel_size: List[int] = [3], stride: int = 1,
                 nempty: bool = False, max_buffer: int = int(2e8)):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = resize_with_last_val(kernel_size)
        self.kernel = list2str(self.kernel_size)
        self.stride = stride
        self.nempty = nempty
        self.max_buffer = max_buffer  # about 200M with the default value
        self.kdim = (
            self.kernel_size[0] * self.kernel_size[1] * self.kernel_size[2])
        # The channel roles are swapped for a deconvolution, so the weight
        # tensor is always laid out as (kdim, in_conv, out_conv).
        self.in_conv = in_channels if self.is_conv_layer() else out_channels
        self.out_conv = out_channels if self.is_conv_layer() else in_channels
        self.weights_shape = (self.kdim, self.in_conv, self.out_conv)

    def is_conv_layer(self):
        r''' Returns :obj:`True` for a convolution layer and :obj:`False` for a
        deconvolution layer. Must be overridden by subclasses.
        '''
        raise NotImplementedError

    def setup(self, octree: Octree, depth: int):
        r''' Setup the shapes of each tensor.

        This function MUST be called before :obj:`forward_gemm`,
        :obj:`backward_gemm` and :obj:`weight_gemm`.

        Args:
          octree (Octree): The octree providing :obj:`nnum`,
              :obj:`nnum_nempty` and :meth:`get_neigh`.
          depth (int): The octree depth of the input data.
        '''
        # The depth of tensors:
        # The in_depth and out_depth are the octree depth of the input and
        # output data; neigh_depth is the octree depth of the neighborhood
        # information, as well as `col` data, neigh_depth is always the same
        # as the depth of larger data when doing octree2col or col2octree.
        self.in_depth = depth
        self.out_depth = depth
        self.neigh_depth = depth
        if self.stride == 2:
            if self.is_conv_layer():
                self.out_depth = depth - 1
            else:
                self.out_depth = depth + 1
                self.neigh_depth = depth + 1

        # The height of tensors
        if self.nempty:
            self.in_h = octree.nnum_nempty[self.in_depth]
            self.out_h = octree.nnum_nempty[self.out_depth]
        else:
            self.in_h = octree.nnum[self.in_depth]
            self.out_h = octree.nnum[self.out_depth]
            # With stride 2 the shallower side is read from `nnum_nempty`
            # even though `nempty` is False.
            if self.stride == 2:
                if self.is_conv_layer():
                    self.out_h = octree.nnum_nempty[self.out_depth]
                else:
                    self.in_h = octree.nnum_nempty[self.in_depth]
        self.in_shape = (self.in_h, self.in_channels)
        self.out_shape = (self.out_h, self.out_channels)

        # The neighborhood indices
        self.neigh = octree.get_neigh(
            self.neigh_depth, self.kernel, self.stride, self.nempty)

        # The height and number of the temporary buffer
        self.buffer_n = 1
        self.buffer_h = self.neigh.shape[0]
        ideal_size = self.buffer_h * self.kdim * self.in_conv
        if ideal_size > self.max_buffer:
            kc = self.kdim * self.in_conv
            # Round `max_buffer` down to a whole number of rows (multiples of
            # `kc`), but never below one row: otherwise a `max_buffer` smaller
            # than `kc` would become 0 and the division below would fail.
            max_buffer = max(self.max_buffer // kc, 1) * kc
            self.buffer_n = (ideal_size + max_buffer - 1) // max_buffer
            self.buffer_h = (
                (self.buffer_h + self.buffer_n - 1) // self.buffer_n)
        self.buffer_shape = (self.buffer_h, self.kdim, self.in_conv)

    def check_and_init(self, data: torch.Tensor):
        r''' Checks the input data and initializes the shape of output data.

        Returns a zero tensor of shape :obj:`self.out_shape` created with
        :obj:`data.new_zeros`.
        '''
        # Check the shape of input data
        check = tuple(data.shape) == self.in_shape
        assert check, ('The shape of input data is wrong: ' +
                       'expected {}, got {}.'.format(self.in_shape, data.shape))

        # Init the output data
        out = data.new_zeros(self.out_shape)
        return out

    def forward_gemm(self, out: torch.Tensor, data: torch.Tensor,
                     weights: torch.Tensor):
        r''' Performs the forward pass of octree-based convolution.

        The rows of :obj:`self.neigh` are processed in :obj:`self.buffer_n`
        blocks of at most :obj:`self.buffer_h` rows; the result of each block
        is written into the matching rows of :obj:`out`, which is returned.
        '''
        # Initialize the buffer
        buffer = data.new_empty(self.buffer_shape)
        weights_2d = weights.flatten(0, 1)  # loop invariant
        num_rows = self.neigh.shape[0]

        # Loop over each sub-matrix
        for i in range(self.buffer_n):
            start = i * self.buffer_h
            end = (i + 1) * self.buffer_h

            # The boundary case in the last iteration: shrink the buffer to
            # the number of remaining rows
            if end > num_rows:
                end = num_rows
                buffer = buffer[:end - start]

            # Perform octree2col: negative indices mark missing neighbors and
            # are left as zeros
            neigh_i = self.neigh[start:end]
            valid = neigh_i >= 0
            buffer.fill_(0)
            buffer[valid] = data[neigh_i[valid]]

            # The sub-matrix gemm
            out[start:end] = torch.mm(buffer.flatten(1, 2), weights_2d)
        return out

    def backward_gemm(self, out: torch.Tensor, grad: torch.Tensor,
                      weights: torch.Tensor):
        r''' Performs the backward pass of octree-based convolution.

        Each block of :obj:`grad` is multiplied with the transposed weights
        and scattered into :obj:`out` through :func:`scatter_add`; the
        accumulated tensor is returned.
        '''
        weights_2d_t = weights.flatten(0, 1).t()  # loop invariant
        num_rows = self.neigh.shape[0]

        # Loop over each sub-matrix
        for i in range(self.buffer_n):
            start = i * self.buffer_h
            # Clamp for the boundary case in the last iteration
            end = min((i + 1) * self.buffer_h, num_rows)

            # The sub-matrix gemm
            buffer = torch.mm(grad[start:end], weights_2d_t)
            buffer = buffer.view(
                -1, self.buffer_shape[1], self.buffer_shape[2])

            # Performs col2octree
            neigh_i = self.neigh[start:end]
            valid = neigh_i >= 0
            out = scatter_add(buffer[valid], neigh_i[valid], dim=0, out=out)
        return out

    def weight_gemm(
            self, out: torch.Tensor, data: torch.Tensor, grad: torch.Tensor):
        r''' Computes the gradient of the weight matrix.

        The gradient is accumulated block by block into :obj:`out` and
        returned with the original shape of :obj:`out`.
        '''
        # Record the shape of out
        out_shape = out.shape
        out = out.flatten(0, 1)

        # Initialize the buffer
        buffer = data.new_empty(self.buffer_shape)
        num_rows = self.neigh.shape[0]

        # Loop over each sub-matrix
        for i in range(self.buffer_n):
            start = i * self.buffer_h
            end = (i + 1) * self.buffer_h

            # The boundary case in the last iteration
            if end > num_rows:
                end = num_rows
                buffer = buffer[:end - start]

            # Perform octree2col
            neigh_i = self.neigh[start:end]
            valid = neigh_i >= 0
            buffer.fill_(0)
            buffer[valid] = data[neigh_i[valid]]

            # Accumulate the gradient via gemm
            out.addmm_(buffer.flatten(1, 2).t(), grad[start:end])
        return out.view(out_shape)
class _OctreeConv(OctreeConvBase):
    r''' Specializes :class:`OctreeConvBase` as a convolution by overriding
    :meth:`is_conv_layer`.
    '''

    def is_conv_layer(self):
        # This variant is the convolution.
        return True
class _OctreeDeconv(OctreeConvBase):
    r''' Specializes :class:`OctreeConvBase` as a deconvolution by overriding
    :meth:`is_conv_layer`.
    '''

    def is_conv_layer(self):
        # This variant is the deconvolution.
        return False
class OctreeConvFunction(Function):
    r''' Autograd wrapper around the block-wise octree convolution.
    '''

    @staticmethod
    def forward(ctx, data: torch.Tensor, weights: torch.Tensor, octree: Octree,
                depth: int, in_channels: int, out_channels: int,
                kernel_size: List[int] = [3, 3, 3], stride: int = 1,
                nempty: bool = False, max_buffer: int = int(2e8)):
        r''' Runs :meth:`forward_gemm` of a freshly set-up :class:`_OctreeConv`
        and keeps what :meth:`backward` needs on :obj:`ctx`.
        '''
        conv = _OctreeConv(
            in_channels, out_channels, kernel_size, stride, nempty, max_buffer)
        conv.setup(octree, depth)

        result = conv.check_and_init(data)
        # The weights are converted to the dtype of the data before the gemm
        weights = weights.to(data.dtype)
        result = conv.forward_gemm(result, data, weights)

        ctx.save_for_backward(data, weights)
        ctx.octree_conv = conv
        return result

    @staticmethod
    def backward(ctx, grad):
        r''' Returns the gradients w.r.t. :obj:`data` and :obj:`weights`; each
        is :obj:`None` when :obj:`ctx.needs_input_grad` says it is not needed.
        '''
        data, weights = ctx.saved_tensors
        conv = ctx.octree_conv

        grad_data = (
            conv.backward_gemm(torch.zeros_like(data), grad, weights)
            if ctx.needs_input_grad[0] else None)
        grad_weights = (
            conv.weight_gemm(torch.zeros_like(weights), data, grad)
            if ctx.needs_input_grad[1] else None)

        # One entry per forward input: the remaining 8 are non-tensor arguments
        return (grad_data, grad_weights) + (None,) * 8
class OctreeDeconvFunction(Function):
    r''' Autograd wrapper around the block-wise octree deconvolution.
    '''

    @staticmethod
    def forward(ctx, data: torch.Tensor, weights: torch.Tensor, octree: Octree,
                depth: int, in_channels: int, out_channels: int,
                kernel_size: List[int] = [3, 3, 3], stride: int = 1,
                nempty: bool = False, max_buffer: int = int(2e8)):
        r''' Runs :meth:`backward_gemm` of a freshly set-up
        :class:`_OctreeDeconv` as the forward pass and keeps what
        :meth:`backward` needs on :obj:`ctx`.
        '''
        deconv = _OctreeDeconv(
            in_channels, out_channels, kernel_size, stride, nempty, max_buffer)
        deconv.setup(octree, depth)

        result = deconv.check_and_init(data)
        # The weights are converted to the dtype of the data before the gemm
        weights = weights.to(data.dtype)
        result = deconv.backward_gemm(result, data, weights)

        ctx.save_for_backward(data, weights)
        ctx.octree_deconv = deconv
        return result

    @staticmethod
    def backward(ctx, grad):
        r''' Returns the gradients w.r.t. :obj:`data` and :obj:`weights`; each
        is :obj:`None` when :obj:`ctx.needs_input_grad` says it is not needed.
        '''
        data, weights = ctx.saved_tensors
        deconv = ctx.octree_deconv

        grad_data = (
            deconv.forward_gemm(torch.zeros_like(data), grad, weights)
            if ctx.needs_input_grad[0] else None)
        # Note the argument order: `grad` takes the data slot here
        grad_weights = (
            deconv.weight_gemm(torch.zeros_like(weights), grad, data)
            if ctx.needs_input_grad[1] else None)

        # One entry per forward input: the remaining 8 are non-tensor arguments
        return (grad_data, grad_weights) + (None,) * 8
# Functional aliases: `octree_conv` and `octree_deconv` are the `apply` entry
# points of the two autograd Functions defined above.
octree_conv = OctreeConvFunction.apply
octree_deconv = OctreeDeconvFunction.apply
class OctreeConv(OctreeConvBase, torch.nn.Module):
    r''' Performs octree convolution.

    Args:
      in_channels (int): Number of input channels.
      out_channels (int): Number of output channels.
      kernel_size (List(int)): The kernel shape, choose from :obj:`[3]`,
          :obj:`[2]`, :obj:`[3,3,3]`, :obj:`[3,1,1]`, :obj:`[1,3,1]`,
          :obj:`[1,1,3]`, :obj:`[2,2,2]`, :obj:`[3,3,1]`, :obj:`[1,3,3]`, and
          :obj:`[3,1,3]`.
      stride (int): The stride of the convolution (:obj:`1` or :obj:`2`).
      nempty (bool): If True, only performs the convolution on non-empty
          octree nodes.
      method (str): Which implementation to use. Options are
          :obj:`'explicit_gemm'`, :obj:`'block_gemm'`, and :obj:`'triton'`.
          :obj:`'explicit_gemm'` builds the full column matrix via
          octree2col/col2octree and then uses GEMM; this can use a large
          amount of memory. :obj:`'block_gemm'` computes in smaller blocks to
          reduce peak memory at some runtime cost. :obj:`'triton'` uses the
          implicit kernel and requires kernel_size=[3,3,3], stride=1, CUDA,
          and PyTorch >= 2.7.0; when these requirements are not met, or the
          environment variable OCNN_DISABLE_TRITON is set to 1, it silently
          falls back to :obj:`'block_gemm'`.
      use_bias (bool): If True, add a bias term to the convolution.
      max_buffer (int): The maximum number of elements in the buffer, used
          when :attr:`method` is 'block_gemm'.

    .. note::
      Each non-empty octree node has exactly 8 children nodes, among which
      some children nodes are non-empty and some are empty. If :attr:`nempty`
      is true, the convolution is performed on non-empty octree nodes only,
      which is exactly the same as SparseConvNet and MinkowskiNet; if
      :attr:`nempty` is false, the convolution is performed on all octree
      nodes, which is essential for shape reconstruction tasks and can also be
      used in classification and segmentation (with slightly better
      performance and larger memory cost).
    '''

    def __init__(self, in_channels: int, out_channels: int,
                 kernel_size: List[int] = [3], stride: int = 1,
                 nempty: bool = False, method: str = 'triton',
                 use_bias: bool = False, max_buffer: int = int(2e8)):
        super().__init__(
            in_channels, out_channels, kernel_size, stride, nempty, max_buffer)

        self.use_bias = use_bias
        self.method = self.check_method(method)
        # Allocated uninitialized; the values are set in `reset_parameters`
        self.weights = torch.nn.Parameter(torch.empty(self.weights_shape))
        self.bias = (torch.nn.Parameter(torch.empty(out_channels))
                     if use_bias else None)
        self.reset_parameters()

    def check_method(self, method: str):
        r''' Returns the method that will actually be used.

        :obj:`'triton'` is replaced by :obj:`'block_gemm'` when the kernel is
        not :obj:`'333'`, the stride is not 1, Triton is disabled through
        OCNN_DISABLE_TRITON, PyTorch is older than 2.7.0, or CUDA is not
        available. Any other value is returned unchanged (it is validated in
        :meth:`forward`).
        '''
        if method != 'triton':
            return method

        # NOTE(review): an earlier version of the documentation stated
        # PyTorch >= 2.8.0 while this check uses 2.7.0 -- confirm the intended
        # minimum version.
        torch_too_old = (
            version.parse(torch.__version__) < version.parse('2.7.0'))
        unsupported = (self.kernel != '333' or self.stride != 1 or
                       DISABLE_TRITON or torch_too_old or
                       not torch.cuda.is_available())
        # The fallback is silent on purpose: no warning is emitted
        return 'block_gemm' if unsupported else method

    def reset_parameters(self):
        r''' Initializes the weights with :func:`xavier_uniform_` and the bias,
        if any, with zeros.
        '''
        xavier_uniform_(self.weights)
        if self.use_bias:
            torch.nn.init.zeros_(self.bias)

    def is_conv_layer(self):
        r''' Returns :obj:`True`: this module is a convolution layer.
        '''
        return True

    def explicit_gemm(self, data: torch.Tensor, octree: Octree, depth: int):
        r''' Performs the convolution via explicitly constructing the `col`
        data.
        '''
        col = octree2col(
            data, octree, depth, self.kernel, self.stride, self.nempty)
        out = torch.mm(col.flatten(1), self.weights.flatten(0, 1))
        if self.use_bias:
            out += self.bias
        return out

    def block_gemm(self, data: torch.Tensor, octree: Octree, depth: int):
        r''' Performs the convolution in a block manner, which can save the
        required runtime memory.
        '''
        out = octree_conv(
            data, self.weights, octree, depth, self.in_channels,
            self.out_channels, self.kernel_size, self.stride, self.nempty,
            self.max_buffer)
        if self.use_bias:
            out += self.bias
        return out

    def implicit_gemm(self, data: torch.Tensor, octree: Octree, depth: int):
        r''' Performs the convolution via the implicit GEMM kernel implemented
        in Triton. The bias (possibly :obj:`None`) is handed to the kernel.
        '''
        weight = self.weights.permute(2, 0, 1)  # (V,Ci,Co) -> (Co,V,Ci)
        neigh = octree.get_neigh(depth, self.kernel, self.stride, self.nempty)
        out = octree_conv_triton(data, weight, self.bias, neigh)
        return out

    def forward(self, data: torch.Tensor, octree: Octree, depth: int):
        r''' Defines the octree convolution.

        Args:
          data (torch.Tensor): The input data.
          octree (Octree): The corresponding octree.
          depth (int): The depth of current octree.

        Raises:
          ValueError: If :attr:`method` is not one of the known methods.
        '''
        if self.method == 'explicit_gemm':
            out = self.explicit_gemm(data, octree, depth)
        elif self.method == 'block_gemm':
            out = self.block_gemm(data, octree, depth)
        elif self.method == 'triton':
            out = self.implicit_gemm(data, octree, depth)
        else:
            raise ValueError('Unknown method: {}'.format(self.method))

        # With stride 2 on all nodes, the output is padded at `depth - 1`
        if self.stride == 2 and not self.nempty:
            out = octree_pad(out, octree, depth-1)
        return out

    def extra_repr(self) -> str:
        r''' Sets the extra representation of the module.
        '''
        return ('in_channels={}, out_channels={}, kernel_size={}, stride={}, '
                'nempty={}, bias={}, method={}').format(
                    self.in_channels, self.out_channels, self.kernel_size,
                    self.stride, self.nempty, self.use_bias, self.method)
class OctreeDeconv(OctreeConv):
    r''' Performs octree deconvolution.

    Please refer to :class:`OctreeConv` for the meaning of the arguments.
    Only :obj:`'explicit_gemm'` has a dedicated code path in :meth:`forward`;
    every other value of :attr:`method` (including :obj:`'triton'`) uses the
    block-wise implementation.
    '''

    def is_conv_layer(self):
        r''' Returns :obj:`False`: this module is a deconvolution layer.
        '''
        return False

    def forward(self, data: torch.Tensor, octree: Octree, depth: int):
        r''' Defines the octree deconvolution.

        Please refer to :meth:`OctreeConv.forward` for the meaning of the
        arguments.
        '''
        depth_col = depth
        if self.stride == 2:
            # The `col` data lives one level deeper than the input
            depth_col = depth + 1
            if not self.nempty:
                data = octree_depad(data, octree, depth)

        if self.method == 'explicit_gemm':
            col = torch.mm(data, self.weights.flatten(0, 1).t())
            col = col.view(col.shape[0], self.kdim, -1)
            out = col2octree(
                col, octree, depth_col, self.kernel, self.stride, self.nempty)
        else:
            out = octree_deconv(
                data, self.weights, octree, depth, self.in_channels,
                self.out_channels, self.kernel_size, self.stride, self.nempty,
                self.max_buffer)

        if self.use_bias:
            out += self.bias
        return out