# Source code for ocnn.nn.octree_conv

# --------------------------------------------------------
# Octree-based Sparse Convolutional Neural Networks
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
# Licensed under The MIT License [see LICENSE for details]
# Written by Peng-Shuai Wang
# --------------------------------------------------------

import os
import torch
import torch.nn
# import warnings
from torch.autograd import Function
from packaging import version
from typing import List

from ocnn.octree import Octree
from ocnn.utils import scatter_add, xavier_uniform_, resize_with_last_val, list2str
from .octree2col import octree2col, col2octree
from .octree_pad import octree_pad, octree_depad
from .octree_conv_t import octree_conv_triton


# Setting the environment variable OCNN_DISABLE_TRITON=1 forces OctreeConv to
# fall back from the 'triton' method to 'block_gemm'.
DISABLE_TRITON = os.environ.get('OCNN_DISABLE_TRITON', '0') == '1'


class OctreeConvBase:
  r''' Shared machinery for block-wise octree convolution and deconvolution.

  Subclasses override :meth:`is_conv_layer` to select the convolution
  (:obj:`True`) or the deconvolution (:obj:`False`) variant. For a
  deconvolution the roles of :attr:`in_channels` and :attr:`out_channels` are
  swapped when laying out the weights and the temporary ``col`` buffer.

  The ``col`` matrix (one row per row of the neighborhood table, each row
  holding ``kdim * in_conv`` elements) is never materialized at once. It is
  processed in :attr:`buffer_n` chunks of at most :attr:`buffer_h` rows, so
  each chunk holds at most :attr:`max_buffer` elements (or a single row, if
  one row alone is already larger than that).
  '''

  def __init__(self, in_channels: int, out_channels: int,
               kernel_size: List[int] = [3], stride: int = 1,
               nempty: bool = False, max_buffer: int = int(2e8)):
    super().__init__()
    self.in_channels = in_channels
    self.out_channels = out_channels
    # Expanded to three entries: `kdim` below indexes [0], [1] and [2].
    self.kernel_size = resize_with_last_val(kernel_size)
    self.kernel = list2str(self.kernel_size)
    self.stride = stride
    self.nempty = nempty
    self.max_buffer = max_buffer  # about 200M elements by default

    self.kdim = self.kernel_size[0] * self.kernel_size[1] * self.kernel_size[2]
    # Weights are laid out as (kdim, in_conv, out_conv); a deconvolution
    # reuses the convolution layout with the channel roles swapped.
    self.in_conv = in_channels if self.is_conv_layer() else out_channels
    self.out_conv = out_channels if self.is_conv_layer() else in_channels
    self.weights_shape = (self.kdim, self.in_conv, self.out_conv)

  def is_conv_layer(self):
    r''' Returns :obj:`True` to indicate this is a convolution layer.
    '''

    raise NotImplementedError

  def setup(self, octree: Octree, depth: int):
    r''' Setup the shapes of each tensor.
    This function MUST be called before :obj:`forward_gemm`, :obj:`backward_gemm`
    and :obj:`weight_gemm`.
    '''

    # The depth of tensors:
    # The in_depth and out_depth are the octree depth of the input and output
    # data; neigh_depth is the octree depth of the neighborhood information, as
    # well as `col` data, neigh_depth is always the same as the depth of larger
    # data when doing octree2col or col2octree.
    self.in_depth = depth
    self.out_depth = depth
    self.neigh_depth = depth
    if self.stride == 2:
      if self.is_conv_layer():
        self.out_depth = depth - 1
      else:
        self.out_depth = depth + 1
        self.neigh_depth = depth + 1

    # The height of tensors
    if self.nempty:
      self.in_h = octree.nnum_nempty[self.in_depth]
      self.out_h = octree.nnum_nempty[self.out_depth]
    else:
      self.in_h = octree.nnum[self.in_depth]
      self.out_h = octree.nnum[self.out_depth]
      if self.stride == 2:
        if self.is_conv_layer():
          self.out_h = octree.nnum_nempty[self.out_depth]
        else:
          self.in_h = octree.nnum_nempty[self.in_depth]
    self.in_shape = (self.in_h, self.in_channels)
    self.out_shape = (self.out_h, self.out_channels)

    # The neighborhood indices
    self.neigh = octree.get_neigh(
        self.neigh_depth, self.kernel, self.stride, self.nempty)

    # The height and number of the temporary buffer
    self.buffer_n = 1
    self.buffer_h = self.neigh.shape[0]
    kc = self.kdim * self.in_conv  # number of elements in one buffer row
    ideal_size = self.buffer_h * kc
    if ideal_size > self.max_buffer:
      # Round `max_buffer` down to a multiple of `kc` (so it is divided by
      # `kc` with no remainder), but keep at least one row per chunk:
      # otherwise a `max_buffer` smaller than `kc` rounds down to zero and
      # the division below raises ZeroDivisionError.
      max_buffer = max(self.max_buffer // kc, 1) * kc
      self.buffer_n = (ideal_size + max_buffer - 1) // max_buffer
      self.buffer_h = (self.buffer_h + self.buffer_n - 1) // self.buffer_n
    self.buffer_shape = (self.buffer_h, self.kdim, self.in_conv)

  def check_and_init(self, data: torch.Tensor):
    r''' Checks the input data and initializes the shape of output data.

    Returns a zero tensor of shape :attr:`out_shape` with the dtype and device
    of :obj:`data`.
    '''

    # Check the shape of input data
    check = tuple(data.shape) == self.in_shape
    assert check, ('The shape of input data is wrong: ' +
                   'expected {}, got {}.'.format(self.in_shape, data.shape))

    # Init the output data
    out = data.new_zeros(self.out_shape)
    return out

  def _buffer_ranges(self):
    r''' Yields the ``(start, end)`` rows of the neighborhood table covered by
    each chunk; the last chunk may be shorter than :attr:`buffer_h`.
    '''

    num_rows = self.neigh.shape[0]
    for i in range(self.buffer_n):
      start = i * self.buffer_h
      yield start, min(start + self.buffer_h, num_rows)

  def _octree2col_block(self, buffer: torch.Tensor, data: torch.Tensor,
                        start: int, end: int):
    r''' Gathers rows ``start:end`` of the ``col`` matrix into ``buffer``.

    Returns a view of the first ``end - start`` rows of ``buffer``. Entries
    whose neighbor index is negative (missing neighbors) are left as zeros.
    '''

    col = buffer[:end - start]
    neigh_i = self.neigh[start:end]
    valid = neigh_i >= 0
    col.fill_(0)
    col[valid] = data[neigh_i[valid]]
    return col

  def forward_gemm(self, out: torch.Tensor, data: torch.Tensor,
                   weights: torch.Tensor):
    r''' Performs the forward pass of octree-based convolution.

    Writes ``octree2col(data) @ weights.flatten(0, 1)`` into ``out`` chunk by
    chunk and returns ``out``. Assumes ``weights`` has shape
    :attr:`weights_shape`.
    '''

    buffer = data.new_empty(self.buffer_shape)
    weights_2d = weights.flatten(0, 1)  # loop-invariant, hoisted
    for start, end in self._buffer_ranges():
      col = self._octree2col_block(buffer, data, start, end)
      out[start:end] = torch.mm(col.flatten(1, 2), weights_2d)
    return out

  def backward_gemm(self, out: torch.Tensor, grad: torch.Tensor,
                    weights: torch.Tensor):
    r''' Performs the backward pass of octree-based convolution.

    Computes ``grad @ weights.flatten(0, 1).t()`` chunk by chunk and scatters
    (col2octree) the result into ``out`` at the valid neighbor indices.
    '''

    weights_t = weights.flatten(0, 1).t()  # loop-invariant, hoisted
    for start, end in self._buffer_ranges():
      # The sub-matrix gemm
      col = torch.mm(grad[start:end], weights_t)
      col = col.view(-1, self.buffer_shape[1], self.buffer_shape[2])

      # Performs col2octree
      neigh_i = self.neigh[start:end]
      valid = neigh_i >= 0
      out = scatter_add(col[valid], neigh_i[valid], dim=0, out=out)
    return out

  def weight_gemm(
          self, out: torch.Tensor, data: torch.Tensor, grad: torch.Tensor):
    r''' Computes the gradient of the weight matrix.

    Accumulates ``octree2col(data).t() @ grad`` into ``out`` chunk by chunk and
    returns the result reshaped to the original shape of ``out``.
    '''

    # Accumulate into a 2-D (kdim * in_conv, out_conv) view of `out`.
    out_shape = out.shape
    out = out.flatten(0, 1)

    buffer = data.new_empty(self.buffer_shape)
    for start, end in self._buffer_ranges():
      col = self._octree2col_block(buffer, data, start, end)
      out.addmm_(col.flatten(1, 2).t(), grad[start:end])
    return out.view(out_shape)


class _OctreeConv(OctreeConvBase):
  r''' Convolution variant of :class:`OctreeConvBase`; it is instantiated by
  :class:`OctreeConvFunction`.
  '''

  def is_conv_layer(self):
    return True


class _OctreeDeconv(OctreeConvBase):
  r''' Deconvolution variant of :class:`OctreeConvBase`; it is instantiated by
  :class:`OctreeDeconvFunction`.
  '''

  def is_conv_layer(self):
    return False


class OctreeConvFunction(Function):
  r''' Autograd wrapper around the block-wise octree convolution.

  The forward pass runs :meth:`OctreeConvBase.forward_gemm`. The backward pass
  computes the input gradient with :meth:`OctreeConvBase.backward_gemm` and the
  weight gradient with :meth:`OctreeConvBase.weight_gemm`.
  '''

  @staticmethod
  def forward(ctx, data: torch.Tensor, weights: torch.Tensor, octree: Octree,
              depth: int, in_channels: int, out_channels: int,
              kernel_size: List[int] = [3, 3, 3], stride: int = 1,
              nempty: bool = False, max_buffer: int = int(2e8)):
    conv = _OctreeConv(in_channels, out_channels, kernel_size, stride,
                       nempty, max_buffer)
    conv.setup(octree, depth)
    result = conv.check_and_init(data)
    # The GEMMs run in the dtype of the input features.
    weights = weights.to(data.dtype)
    result = conv.forward_gemm(result, data, weights)

    ctx.octree_conv = conv
    ctx.save_for_backward(data, weights)
    return result

  @staticmethod
  def backward(ctx, grad):
    data, weights = ctx.saved_tensors
    conv = ctx.octree_conv
    need_data_grad, need_weight_grad = ctx.needs_input_grad[:2]

    grad_data = (conv.backward_gemm(torch.zeros_like(data), grad, weights)
                 if need_data_grad else None)
    grad_weights = (conv.weight_gemm(torch.zeros_like(weights), data, grad)
                    if need_weight_grad else None)

    # octree, depth, in_channels, out_channels, kernel_size, stride, nempty
    # and max_buffer receive no gradient.
    return (grad_data, grad_weights) + (None,) * 8


class OctreeDeconvFunction(Function):
  r''' Autograd wrapper around the block-wise octree deconvolution.

  The forward pass runs :meth:`OctreeConvBase.backward_gemm`. The input
  gradient is computed with :meth:`OctreeConvBase.forward_gemm`, and the
  weight gradient with :meth:`OctreeConvBase.weight_gemm` with the roles of
  the input data and the output gradient swapped.
  '''

  @staticmethod
  def forward(ctx, data: torch.Tensor, weights: torch.Tensor, octree: Octree,
              depth: int, in_channels: int, out_channels: int,
              kernel_size: List[int] = [3, 3, 3], stride: int = 1,
              nempty: bool = False, max_buffer: int = int(2e8)):
    deconv = _OctreeDeconv(in_channels, out_channels, kernel_size, stride,
                           nempty, max_buffer)
    deconv.setup(octree, depth)
    result = deconv.check_and_init(data)
    # The GEMMs run in the dtype of the input features.
    weights = weights.to(data.dtype)
    result = deconv.backward_gemm(result, data, weights)

    ctx.octree_deconv = deconv
    ctx.save_for_backward(data, weights)
    return result

  @staticmethod
  def backward(ctx, grad):
    data, weights = ctx.saved_tensors
    deconv = ctx.octree_deconv
    need_data_grad, need_weight_grad = ctx.needs_input_grad[:2]

    grad_data = None
    if need_data_grad:
      grad_data = deconv.forward_gemm(torch.zeros_like(data), grad, weights)

    grad_weights = None
    if need_weight_grad:
      # `grad` plays the role of the gathered input here, `data` of the
      # output gradient.
      grad_weights = deconv.weight_gemm(torch.zeros_like(weights), grad, data)

    # octree, depth, in_channels, out_channels, kernel_size, stride, nempty
    # and max_buffer receive no gradient.
    return (grad_data, grad_weights) + (None,) * 8


# Functional aliases; positional arguments follow the `forward` signatures of
# OctreeConvFunction / OctreeDeconvFunction.
octree_conv, octree_deconv = (
    OctreeConvFunction.apply, OctreeDeconvFunction.apply)


class OctreeConv(OctreeConvBase, torch.nn.Module):
  r''' Performs octree convolution.

  Args:
    in_channels (int): Number of input channels.
    out_channels (int): Number of output channels.
    kernel_size (List(int)): The kernel shape, choose from :obj:`[3]`,
        :obj:`[2]`, :obj:`[3,3,3]`, :obj:`[3,1,1]`, :obj:`[1,3,1]`,
        :obj:`[1,1,3]`, :obj:`[2,2,2]`, :obj:`[3,3,1]`, :obj:`[1,3,3]`, and
        :obj:`[3,1,3]`.
    stride (int): The stride of the convolution (:obj:`1` or :obj:`2`).
    nempty (bool): If True, only performs the convolution on non-empty
        octree nodes.
    method (str): Which implementation to use. Options are
        :obj:`'explicit_gemm'`, :obj:`'block_gemm'`, and :obj:`'triton'`.
        :obj:`'explicit_gemm'` builds the full column matrix via
        octree2col/col2octree and then uses GEMM; this can use a large amount
        of memory. :obj:`'block_gemm'` computes in smaller blocks to reduce
        peak memory at some runtime cost. :obj:`'triton'` uses the implicit
        kernel and requires kernel_size=[3,3,3], stride=1, CUDA, and
        PyTorch >= 2.7.0; otherwise :obj:`'block_gemm'` is used instead.
    use_bias (bool): If True, add a bias term to the convolution.
    max_buffer (int): The maximum number of elements in the buffer, used when
        :attr:`method` is 'block_gemm'.

  .. note::
    Each non-empty octree node has exactly 8 children nodes, among which some
    children nodes are non-empty and some are empty. If :attr:`nempty` is
    true, the convolution is performed on non-empty octree nodes only, which
    is exactly the same as SparseConvNet and MinkowskiNet; if :attr:`nempty`
    is false, the convolution is performed on all octree nodes, which is
    essential for shape reconstruction tasks and can also be used in
    classification and segmentation (with slightly better performance and
    larger memory cost).
  '''

  def __init__(self, in_channels: int, out_channels: int,
               kernel_size: List[int] = [3], stride: int = 1,
               nempty: bool = False, method: str = 'triton',
               use_bias: bool = False, max_buffer: int = int(2e8)):
    super().__init__(
        in_channels, out_channels, kernel_size, stride, nempty, max_buffer)
    self.use_bias = use_bias
    self.method = self.check_method(method)
    self.weights = torch.nn.Parameter(torch.Tensor(*self.weights_shape))
    self.bias = (torch.nn.Parameter(torch.Tensor(out_channels))
                 if use_bias else None)
    self.reset_parameters()

  def check_method(self, method: str):
    r''' Returns the implementation that will actually be used for
    :obj:`method`.

    :obj:`'triton'` is replaced by :obj:`'block_gemm'` when the Triton kernel
    cannot be used (unsupported kernel/stride, disabled via the
    ``OCNN_DISABLE_TRITON`` environment variable, PyTorch older than 2.7.0, or
    no CUDA). All other values are returned unchanged.
    '''

    if method == 'triton':
      # NOTE(review): earlier documentation of this module stated that Triton
      # needs PyTorch >= 2.8.0, but the check uses 2.7.0 -- confirm the
      # actual minimum version.
      too_old = version.parse(torch.__version__) < version.parse('2.7.0')
      if (self.kernel != '333' or self.stride != 1 or DISABLE_TRITON or
              too_old or not torch.cuda.is_available()):
        method = 'block_gemm'
    return method

  def reset_parameters(self):
    xavier_uniform_(self.weights)
    if self.use_bias:
      torch.nn.init.zeros_(self.bias)

  def is_conv_layer(self):
    return True

  def explicit_gemm(self, data: torch.Tensor, octree: Octree, depth: int):
    r''' Performs the convolution via explicitly constructing the `col` data.
    '''

    col = octree2col(
        data, octree, depth, self.kernel, self.stride, self.nempty)
    out = torch.mm(col.flatten(1), self.weights.flatten(0, 1))
    if self.use_bias:
      out += self.bias
    return out

  def block_gemm(self, data: torch.Tensor, octree: Octree, depth: int):
    r''' Performs the convolution in a block manner, which can save the
    required runtime memory.
    '''

    out = octree_conv(
        data, self.weights, octree, depth, self.in_channels,
        self.out_channels, self.kernel_size, self.stride, self.nempty,
        self.max_buffer)
    if self.use_bias:
      out += self.bias
    return out

  def implicit_gemm(self, data: torch.Tensor, octree: Octree, depth: int):
    r''' Performs the convolution via the implicit GEMM kernel implemented in
    Triton.
    '''

    weight = self.weights.permute(2, 0, 1)  # (V,Ci,Co) -> (Co,V,Ci)
    neigh = octree.get_neigh(depth, self.kernel, self.stride, self.nempty)
    out = octree_conv_triton(data, weight, self.bias, neigh)
    return out

  def forward(self, data: torch.Tensor, octree: Octree, depth: int):
    r''' Defines the octree convolution.

    Args:
      data (torch.Tensor): The input data.
      octree (Octree): The corresponding octree.
      depth (int): The depth of current octree.
    '''

    method = self.method
    if method == 'triton' and not data.is_cuda:
      # `check_method` only verifies at construction time that CUDA is
      # available; the Triton path requires CUDA, so CPU inputs (e.g. a model
      # moved to the CPU) take the equivalent block-wise path instead.
      method = 'block_gemm'

    if method == 'explicit_gemm':
      out = self.explicit_gemm(data, octree, depth)
    elif method == 'block_gemm':
      out = self.block_gemm(data, octree, depth)
    elif method == 'triton':
      out = self.implicit_gemm(data, octree, depth)
    else:
      raise ValueError('Unknown method: {}'.format(self.method))

    if self.stride == 2 and not self.nempty:
      out = octree_pad(out, octree, depth-1)
    return out

  def extra_repr(self) -> str:
    r''' Sets the extra representation of the module.
    '''

    return ('in_channels={}, out_channels={}, kernel_size={}, stride={}, '
            'nempty={}, bias={}, method={}').format(
                self.in_channels, self.out_channels, self.kernel_size,
                self.stride, self.nempty, self.use_bias, self.method)  # noqa
class OctreeDeconv(OctreeConv):
  r''' Performs octree deconvolution.

  Please refer to :class:`OctreeConv` for the meaning of the arguments. Only
  :obj:`'explicit_gemm'` has a dedicated path here; :obj:`'block_gemm'` and
  :obj:`'triton'` both run the block-wise implementation.
  '''

  def is_conv_layer(self):
    return False

  def forward(self, data: torch.Tensor, octree: Octree, depth: int):
    r''' Defines the octree deconvolution.

    Please refer to :meth:`OctreeConv.forward` for the meaning of the
    arguments.

    Raises:
      ValueError: If :attr:`method` is not one of :obj:`'explicit_gemm'`,
          :obj:`'block_gemm'` or :obj:`'triton'` (consistent with
          :meth:`OctreeConv.forward`).
    '''

    depth_col = depth
    if self.stride == 2:
      depth_col = depth + 1
      if not self.nempty:
        # Keep only the non-empty nodes, mirroring the `octree_pad` applied
        # by the strided OctreeConv.
        data = octree_depad(data, octree, depth)

    if self.method == 'explicit_gemm':
      col = torch.mm(data, self.weights.flatten(0, 1).t())
      col = col.view(col.shape[0], self.kdim, -1)
      out = col2octree(
          col, octree, depth_col, self.kernel, self.stride, self.nempty)
    elif self.method in ('block_gemm', 'triton'):
      out = octree_deconv(
          data, self.weights, octree, depth, self.in_channels,
          self.out_channels, self.kernel_size, self.stride, self.nempty,
          self.max_buffer)
    else:
      # Previously any unrecognized string silently ran the block path.
      raise ValueError('Unknown method: {}'.format(self.method))

    if self.use_bias:
      out += self.bias
    return out