# --------------------------------------------------------
# Octree-based Sparse Convolutional Neural Networks
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
# Licensed under The MIT License [see LICENSE for details]
# Written by Peng-Shuai Wang
# --------------------------------------------------------
import torch
import torch.nn
from torch.autograd import Function
from typing import List
from ocnn.octree import Octree
from ocnn.utils import scatter_add, xavier_uniform_
from .octree_pad import octree_pad
from .octree_conv import OctreeConvBase
class OctreeDWConvBase(OctreeConvBase):
    r''' Chunked kernels of the octree-based depth-wise convolution.

    The three ``*_gemm`` methods walk over the rows of :attr:`self.neigh` in
    :attr:`self.buffer_n` chunks of :attr:`self.buffer_h` rows each. Every
    :func:`torch.einsum` used here keeps the channel axis ``c`` un-mixed, i.e.
    each channel is only combined with its own kernel weights.

    .. note::
      NOTE(review): :attr:`kdim`, :attr:`out_channels`, :attr:`neigh`,
      :attr:`buffer_n`, :attr:`buffer_h` and :attr:`buffer_shape` are not
      defined in this class; presumably they are provided by
      :class:`OctreeConvBase` and its :func:`setup` - confirm.
    '''

    def __init__(self, in_channels: int, kernel_size: List[int] = [3],
                 stride: int = 1, nempty: bool = False,
                 max_buffer: int = int(2e8)):
        r''' Initializes the parent class and the depth-wise weight shape.

        Args:
          in_channels (int): Forwarded to the parent class as both its first
              and second argument, so the input and output channel counts
              are the same.
          kernel_size (List[int]): Forwarded unchanged to the parent class.
          stride (int): Forwarded unchanged to the parent class.
          nempty (bool): Forwarded unchanged to the parent class.
          max_buffer (int): Forwarded unchanged to the parent class.
        '''

        super().__init__(
            in_channels, in_channels, kernel_size, stride, nempty, max_buffer)

        # A single "1" replaces the input-channel dim of the weights.
        self.weights_shape = (self.kdim, 1, self.out_channels)

    def is_conv_layer(self):
        r''' Always returns :obj:`True`.
        '''

        return True

    def forward_gemm(self, out: torch.Tensor, data: torch.Tensor,
                     weights: torch.Tensor):
        r''' Performs the forward pass of octree-based convolution.

        Args:
          out (torch.Tensor): The output tensor; its rows ``[start:end]`` are
              overwritten chunk by chunk.
          data (torch.Tensor): The input features, indexed along the first
              dim by the non-negative entries of :attr:`self.neigh`.
          weights (torch.Tensor): The kernel weights; the first two dims are
              flattened before the contraction.

        Returns:
          The tensor :attr:`out`.
        '''

        # The flattened kernel is the same for every chunk.
        kernel = weights.flatten(0, 1)

        # One buffer is allocated and re-used by all chunks.
        buffer = data.new_empty(self.buffer_shape)

        for chunk in range(self.buffer_n):
            start = chunk * self.buffer_h
            end = start + self.buffer_h

            # The last chunk may be shorter: keep only the leading rows of
            # the buffer so that it matches `self.neigh[start:end]`.
            overflow = end - self.neigh.shape[0]
            if overflow > 0:
                end = self.neigh.shape[0]
                buffer = buffer.split(
                    [self.buffer_h - overflow, overflow])[0]

            # octree2col: gather the features of the valid neighbors and
            # leave zeros where the neighbor index is negative.
            neigh_i = self.neigh[start:end]
            valid = neigh_i >= 0
            buffer.fill_(0)
            buffer[valid] = data[neigh_i[valid]]

            # Depth-wise counterpart of the dense `torch.mm`: sum over the
            # kernel axis `k` only and keep the channel axis `c`.
            out[start:end] = torch.einsum('ikc,kc->ic', buffer, kernel)

        return out

    def backward_gemm(self, out: torch.Tensor, grad: torch.Tensor,
                      weights: torch.Tensor):
        r''' Performs the backward pass of octree-based convolution.

        Args:
          out (torch.Tensor): The tensor passed as :attr:`out` to
              :func:`scatter_add`, which receives the scattered gradients.
          grad (torch.Tensor): The gradient with respect to the output of
              :func:`forward_gemm`.
          weights (torch.Tensor): The kernel weights; the first two dims are
              flattened before the product.

        Returns:
          The result of the last :func:`scatter_add` call, or :attr:`out`
          itself if there is no chunk to process.
        '''

        # The flattened kernel is the same for every chunk.
        kernel = weights.flatten(0, 1)

        for chunk in range(self.buffer_n):
            start = chunk * self.buffer_h
            # Clamp the last chunk to the number of rows of `self.neigh`.
            end = min(start + self.buffer_h, self.neigh.shape[0])

            neigh_i = self.neigh[start:end]
            valid = neigh_i >= 0

            # Depth-wise counterpart of `torch.mm(grad, weights.t())`: an
            # outer product over `i` and `k` for each channel `c`.
            buffer = torch.einsum('ic,kc->ikc', grad[start:end], kernel)

            # col2octree: add the entries of the valid neighbors to the rows
            # selected by their neighbor indices.
            out = scatter_add(buffer[valid], neigh_i[valid], dim=0, out=out)

        return out

    def weight_gemm(self, out: torch.Tensor, data: torch.Tensor,
                    grad: torch.Tensor):
        r''' Computes the gradient of the weight matrix.

        Args:
          out (torch.Tensor): The tensor the weight gradient is added to; its
              first two dims are flattened during the accumulation.
          data (torch.Tensor): The input features, indexed along the first
              dim by the non-negative entries of :attr:`self.neigh`.
          grad (torch.Tensor): The gradient with respect to the output of
              :func:`forward_gemm`.

        Returns:
          The accumulated gradient, viewed with the original shape of
          :attr:`out`.
        '''

        # Remember the shape so that the result can be viewed back.
        original_shape = out.shape
        out = out.flatten(0, 1)

        # One buffer is allocated and re-used by all chunks.
        buffer = data.new_empty(self.buffer_shape)

        for chunk in range(self.buffer_n):
            start = chunk * self.buffer_h
            end = start + self.buffer_h

            # The last chunk may be shorter: keep only the leading rows of
            # the buffer so that it matches `self.neigh[start:end]`.
            overflow = end - self.neigh.shape[0]
            if overflow > 0:
                end = self.neigh.shape[0]
                buffer = buffer.split(
                    [self.buffer_h - overflow, overflow])[0]

            # octree2col: gather the features of the valid neighbors and
            # leave zeros where the neighbor index is negative.
            neigh_i = self.neigh[start:end]
            valid = neigh_i >= 0
            buffer.fill_(0)
            buffer[valid] = data[neigh_i[valid]]

            # Depth-wise counterpart of `out.addmm_(buffer.t(), grad)`: sum
            # over the rows `i` for each kernel position and channel.
            out += torch.einsum('ikc,ic->kc', buffer, grad[start:end])

        return out.view(original_shape)
class OctreeDWConvFunction(Function):
    r''' Wraps the octree-based depth-wise convolution for auto-diff.
    '''

    @staticmethod
    def forward(ctx, data: torch.Tensor, weights: torch.Tensor, octree: Octree,
                depth: int, in_channels: int,
                kernel_size: List[int] = [3, 3, 3], stride: int = 1,
                nempty: bool = False, max_buffer: int = int(2e8)):
        r''' Runs :func:`OctreeDWConvBase.forward_gemm` on :attr:`data`.

        Args:
          ctx: The autograd context; it stores :attr:`data`, the cast weights
              and the convolution object for :func:`backward`.
          data (torch.Tensor): The input features.
          weights (torch.Tensor): The kernel weights; they are cast to the
              dtype of :attr:`data` before being used.
          octree (Octree): Passed to :func:`OctreeDWConvBase.setup`.
          depth (int): Passed to :func:`OctreeDWConvBase.setup`.
          in_channels (int): Passed to :class:`OctreeDWConvBase`.
          kernel_size (List[int]): Passed to :class:`OctreeDWConvBase`.
          stride (int): Passed to :class:`OctreeDWConvBase`.
          nempty (bool): Passed to :class:`OctreeDWConvBase`.
          max_buffer (int): Passed to :class:`OctreeDWConvBase`.

        Returns:
          The tensor returned by :func:`OctreeDWConvBase.forward_gemm`.
        '''

        conv = OctreeDWConvBase(
            in_channels, kernel_size, stride, nempty, max_buffer)
        conv.setup(octree, depth)

        result = conv.check_and_init(data)
        cast_weights = weights.to(data.dtype)
        result = conv.forward_gemm(result, data, cast_weights)

        # Keep everything that `backward` needs.
        ctx.save_for_backward(data, cast_weights)
        ctx.octree_conv = conv
        return result

    @staticmethod
    def backward(ctx, grad):
        r''' Computes the gradients with respect to the data and the weights.

        Args:
          ctx: The autograd context filled in :func:`forward`.
          grad (torch.Tensor): The gradient with respect to the output of
              :func:`forward`.

        Returns:
          A tuple with one entry per argument of :func:`forward` after
          :attr:`ctx`: the gradient of the data, the gradient of the weights
          (each :obj:`None` if not required), followed by seven :obj:`None`.
        '''

        data, weights = ctx.saved_tensors
        conv = ctx.octree_conv
        needs_data_grad, needs_weight_grad = ctx.needs_input_grad[:2]

        grad_data = None
        if needs_data_grad:
            grad_data = conv.backward_gemm(
                torch.zeros_like(data), grad, weights)

        grad_weights = None
        if needs_weight_grad:
            grad_weights = conv.weight_gemm(
                torch.zeros_like(weights), data, grad)

        # No gradient for octree, depth, in_channels, kernel_size, stride,
        # nempty and max_buffer.
        no_grads = (None,) * 7
        return (grad_data, grad_weights, *no_grads)
# Functional alias: `octree_dwconv(...)` is `OctreeDWConvFunction.apply(...)`,
# so its arguments are those of `OctreeDWConvFunction.forward` without `ctx`.
octree_dwconv = OctreeDWConvFunction.apply
class OctreeDWConv(OctreeDWConvBase, torch.nn.Module):
    r''' Performs octree-based depth-wise convolution.

    Please refer to :class:`ocnn.nn.OctreeConv` for the meaning of the
    arguments.

    .. note::
      This implementation uses the :func:`torch.einsum` and I find that the
      speed is relatively slow. Further optimization is needed to speed it up.
    '''

    def __init__(self, in_channels: int, kernel_size: List[int] = [3],
                 stride: int = 1, nempty: bool = False, use_bias: bool = False,
                 max_buffer: int = int(2e8)):
        r''' Creates the learnable weights and, optionally, the bias.

        Args:
          in_channels (int): Forwarded to :class:`OctreeDWConvBase`; also the
              length of the bias vector.
          kernel_size (List[int]): Forwarded to :class:`OctreeDWConvBase`.
          stride (int): Forwarded to :class:`OctreeDWConvBase`.
          nempty (bool): Forwarded to :class:`OctreeDWConvBase`.
          use_bias (bool): If True, a learnable bias is created and added to
              the output in :func:`forward`.
          max_buffer (int): Forwarded to :class:`OctreeDWConvBase`.
        '''

        super().__init__(in_channels, kernel_size, stride, nempty, max_buffer)

        self.use_bias = use_bias
        # The parameters are allocated without initialization here; their
        # values are set by `reset_parameters` below.
        self.weights = torch.nn.Parameter(torch.Tensor(*self.weights_shape))
        if self.use_bias:
            self.bias = torch.nn.Parameter(torch.Tensor(in_channels))
        self.reset_parameters()

    def reset_parameters(self):
        r''' Initializes the weights with :func:`xavier_uniform_` and, if
        :attr:`use_bias` is True, fills the bias with zeros.
        '''

        xavier_uniform_(self.weights)
        if self.use_bias:
            torch.nn.init.zeros_(self.bias)

    def forward(self, data: torch.Tensor, octree: Octree, depth: int):
        r''' Applies the depth-wise convolution to :attr:`data`.

        Args:
          data (torch.Tensor): The input features passed to
              :func:`octree_dwconv`.
          octree (Octree): The octree passed to :func:`octree_dwconv` and, if
              padding is needed, to :func:`octree_pad`.
          depth (int): The depth passed to :func:`octree_dwconv`; the padding
              is done with :obj:`depth - 1`.

        Returns:
          The output of :func:`octree_dwconv`, with the bias added if
          :attr:`use_bias` is True, and passed through :func:`octree_pad` if
          :attr:`stride` is 2 and :attr:`nempty` is False.
        '''

        out = octree_dwconv(
            data, self.weights, octree, depth, self.in_channels,
            self.kernel_size, self.stride, self.nempty, self.max_buffer)

        if self.use_bias:
            out += self.bias

        if self.stride == 2 and not self.nempty:
            out = octree_pad(out, octree, depth-1)
        return out

    def extra_repr(self) -> str:
        r''' Returns the summary of the layer settings used in the printed
        representation of the module.
        '''

        return ('in_channels={}, out_channels={}, kernel_size={}, stride={}, '
                'nempty={}, bias={}').format(
                    self.in_channels, self.out_channels, self.kernel_size,
                    self.stride, self.nempty, self.use_bias)