ocnn.nn

octree2voxel

Converts the input feature to the full-voxel-based representation.

octree2col

Gathers the neighboring features for convolutions.

col2octree

Scatters the convolution features to an octree.

octree_pad

Pads data with val so that its number of elements equals the number of octree nodes.

octree_depad

Reverse operation of octree_pad().

octree_nearest_pts

The nearest-neighbor interpolation with input points.

octree_linear_pts

Linear interpolation with input points.

octree_max_pool

Performs octree max pooling with kernel size 2 and stride 2.

octree_max_unpool

Performs octree max unpooling.

octree_global_pool

Performs octree global average pooling.

octree_avg_pool

Performs octree average pooling.

Octree2Voxel

Converts the input feature to the full-voxel-based representation.

OctreeMaxPool

Performs octree max pooling.

OctreeMaxUnpool

Performs octree max unpooling.

OctreeGlobalPool

Performs octree global pooling.

OctreeAvgPool

Performs octree average pooling.

OctreeConv

Performs octree convolution.

OctreeDeconv

Performs octree deconvolution.

OctreeGroupConv

Performs octree-based group convolution.

OctreeDWConv

Performs octree-based depth-wise convolution.

OctreeInterp

Interpolates octree features at the input points.

OctreeUpsample

Upsamples the octree node features from depth to target_depth.

OctreeInstanceNorm

An instance normalization layer for the octree.

OctreeBatchNorm

alias of torch.nn.BatchNorm1d

OctreeGroupNorm

A group normalization layer for the octree.

OctreeNorm

A normalization layer for the octree.

OctreeDropPath

Drop paths (Stochastic Depth) per octree when applied in the main path of residual blocks.

DropPath

Drop paths (Stochastic Depth) per token when applied in the main path of residual blocks, following the logic of timm.models.layers.DropPath().

search_value

Searches values according to sorted shuffled keys.

octree_align

Wraps search_value() to take octrees as input for convenience.

OctreeConvTriton

Performs octree convolution.

OctreeConvT

alias of OctreeConvTriton

convert_conv_triton

Converts OctreeConv modules to OctreeConvTriton modules in a network.

octree2voxel(data: Tensor, octree: Octree, depth: int, nempty: bool = False)

Converts the input feature to the full-voxel-based representation.

Parameters:
  • data (torch.Tensor) – The input feature.

  • octree (Octree) – The corresponding octree.

  • depth (int) – The depth of the current octree.

  • nempty (bool) – If True, data only contains the features of non-empty octree nodes.
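To illustrate the idea, here is a minimal pure-Python sketch (not ocnn's implementation) that scatters node features into a dense grid with 2^depth cells per axis. It assumes the integer (x, y, z) coordinates of the nodes at depth are already available; the helper to_dense_voxels and its inputs are hypothetical. Grid cells without a node stay zero.

```python
from typing import List, Tuple

def to_dense_voxels(
    feats: List[List[float]],
    coords: List[Tuple[int, int, int]],
    depth: int,
) -> List[List[List[List[float]]]]:
    """Hypothetical helper: scatter per-node features into a dense grid.

    The grid is indexed as grid[x][y][z] with 2**depth cells per axis; each
    cell holds a length-C feature vector, and cells without a node stay zero.
    """
    res = 2 ** depth
    channels = len(feats[0]) if feats else 0
    grid = [[[[0.0] * channels for _ in range(res)] for _ in range(res)]
            for _ in range(res)]
    for f, (x, y, z) in zip(feats, coords):
        grid[x][y][z] = list(f)
    return grid

# Two nodes at depth 1 (a 2 x 2 x 2 grid), each with 2 channels.
grid = to_dense_voxels([[1.0, 2.0], [3.0, 4.0]], [(0, 0, 0), (1, 1, 0)], depth=1)
print(grid[1][1][0])  # [3.0, 4.0]
print(grid[0][1][1])  # [0.0, 0.0]
```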

octree2col(data: Tensor, octree: Octree, depth: int, kernel_size: str = '333', stride: int = 1, nempty: bool = False)

Gathers the neighboring features for convolutions.

Parameters:
  • data (torch.Tensor) – The input data.

  • octree (Octree) – The corresponding octree.

  • depth (int) – The depth of the current octree.

  • kernel_size (str) – The kernel shape, choose from 333, 311, 131, 113, 222, 331, 133, and 313.

  • stride (int) – The stride of neighborhoods (1 or 2). If the stride is 2, it always returns the neighborhood of the first sibling in each group of 8 siblings, and the number of elements of the output tensor is octree.nnum[depth] / 8.

  • nempty (bool) – If True, only returns the neighborhoods of the non-empty octree nodes.
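A minimal pure-Python sketch of a stride-1 '333' gather may help. It is not ocnn's implementation: the hypothetical helper gather_neighbors_333 looks nodes up by integer coordinates, visits the 27 offsets in lexicographic (dx, dy, dz) order, and fills missing neighbors with zeros; ocnn's internal neighbor ordering may differ.

```python
from itertools import product
from typing import Dict, List, Tuple

Coord = Tuple[int, int, int]

def gather_neighbors_333(
    feats: List[List[float]], coords: List[Coord]
) -> List[List[List[float]]]:
    """Hypothetical stride-1 '333' gather: N x 27 x C neighbor features.

    Offsets are visited in lexicographic (dx, dy, dz) order, so index 13 is
    the node itself; missing neighbors contribute zero vectors.
    """
    lookup: Dict[Coord, int] = {c: i for i, c in enumerate(coords)}
    channels = len(feats[0]) if feats else 0
    offsets = list(product((-1, 0, 1), repeat=3))
    cols = []
    for x, y, z in coords:
        row = []
        for dx, dy, dz in offsets:
            j = lookup.get((x + dx, y + dy, z + dz))
            row.append(list(feats[j]) if j is not None else [0.0] * channels)
        cols.append(row)
    return cols

cols = gather_neighbors_333([[1.0], [2.0]], [(0, 0, 0), (1, 0, 0)])
print(len(cols), len(cols[0]))   # 2 27
print(cols[0][13], cols[0][22])  # [1.0] [2.0]
```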

col2octree(data: Tensor, octree: Octree, depth: int, kernel_size: str = '333', stride: int = 1, nempty: bool = False)

Scatters the convolution features to an octree.

Please refer to octree2col() for the usage of function parameters.
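Conceptually, col2octree() acts as the adjoint of octree2col(): where the gather copies a neighbor's feature into a column slot, the scatter adds that slot back onto the neighbor. The sketch below is an illustration, not ocnn's code; it assumes a hypothetical coordinate lookup and the lexicographic (dx, dy, dz) offset order, and slots that point at missing nodes are discarded.

```python
from itertools import product
from typing import Dict, List, Tuple

Coord = Tuple[int, int, int]

def scatter_neighbors_333(
    cols: List[List[List[float]]], coords: List[Coord]
) -> List[List[float]]:
    """Hypothetical sketch: sum N x 27 x C column data back onto the nodes."""
    lookup: Dict[Coord, int] = {c: i for i, c in enumerate(coords)}
    channels = len(cols[0][0]) if cols else 0
    offsets = list(product((-1, 0, 1), repeat=3))
    out = [[0.0] * channels for _ in coords]
    for (x, y, z), row in zip(coords, cols):
        for (dx, dy, dz), vec in zip(offsets, row):
            j = lookup.get((x + dx, y + dy, z + dz))
            if j is None:
                continue  # the slot refers to a node that does not exist
            for c in range(channels):
                out[j][c] += vec[c]
    return out

# Two adjacent nodes; every column slot holds 1.0. Each node receives its own
# center slot plus the slot of the other node that points at it.
coords = [(0, 0, 0), (1, 0, 0)]
ones = [[[1.0] for _ in range(27)] for _ in coords]
print(scatter_neighbors_333(ones, coords))  # [[2.0], [2.0]]
```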

octree_pad(data: Tensor, octree: Octree, depth: int, val: float = 0.0)

Pads data with val so that its number of elements equals the number of octree nodes.

Parameters:
  • data (torch.Tensor) – The input tensor, whose number of elements equals the number of non-empty octree nodes.

  • octree (Octree) – The corresponding octree.

  • depth (int) – The depth of the current octree.

  • val (float) – The padding value. (Default: 0.0)

octree_depad(data: Tensor, octree: Octree, depth: int)

Reverse operation of octree_pad().

Please refer to octree_pad() for the meaning of the arguments.
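The pad/depad pair can be pictured with a per-node non-empty mask. This pure-Python sketch assumes the nodes at depth are listed in a fixed order and that the non-empty nodes keep their relative order in data; the helpers pad_nonempty and depad_nonempty are hypothetical, not ocnn functions.

```python
from typing import List

def pad_nonempty(data: List[List[float]], nonempty: List[bool],
                 val: float = 0.0) -> List[List[float]]:
    """Hypothetical: expand rows of non-empty nodes to one row per node."""
    channels = len(data[0]) if data else 0
    rows = iter(data)
    return [list(next(rows)) if keep else [val] * channels for keep in nonempty]

def depad_nonempty(data: List[List[float]], nonempty: List[bool]) -> List[List[float]]:
    """Hypothetical inverse: keep only the rows of non-empty nodes."""
    return [list(row) for row, keep in zip(data, nonempty) if keep]

mask = [True, False, True, False]
padded = pad_nonempty([[1.0], [2.0]], mask, val=-1.0)
print(padded)                        # [[1.0], [-1.0], [2.0], [-1.0]]
print(depad_nonempty(padded, mask))  # [[1.0], [2.0]]
```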

octree_nearest_pts(data: Tensor, octree: Octree, depth: int, pts: Tensor, nempty: bool = False, bound_check: bool = False)

The nearest-neighbor interpolation with input points.

Parameters:
  • data (torch.Tensor) – The input data.

  • octree (Octree) – The octree to interpolate.

  • depth (int) – The depth of the data.

  • pts (torch.Tensor) – The coordinates of the points with shape (N, 4), i.e. N x (x, y, z, batch).

  • nempty (bool) – If True, data only contains features of non-empty octree nodes.

  • bound_check (bool) – If True, checks whether the points are in [0, 2^depth).

Note

The pts MUST be scaled into [0, 2^depth).
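As an illustration of nearest-neighbor lookup, the sketch below assumes node coordinates are available as integer (x, y, z, batch) tuples and that pts has already been scaled into [0, 2^depth). Each point takes the feature of the node whose voxel contains it, or zeros if that voxel has no node. The helper nearest_interp is hypothetical and ignores the nempty and bound_check options.

```python
import math
from typing import Dict, List, Tuple

Key = Tuple[int, int, int, int]  # (x, y, z, batch)

def nearest_interp(
    feats: List[List[float]],
    coords: List[Key],
    pts: List[Tuple[float, float, float, int]],
) -> List[List[float]]:
    """Hypothetical sketch: copy the feature of the voxel containing each point."""
    lookup: Dict[Key, int] = {c: i for i, c in enumerate(coords)}
    channels = len(feats[0]) if feats else 0
    out = []
    for x, y, z, b in pts:
        cell = (math.floor(x), math.floor(y), math.floor(z), b)
        j = lookup.get(cell)
        out.append(list(feats[j]) if j is not None else [0.0] * channels)
    return out

feats = [[5.0], [7.0]]
coords = [(0, 0, 0, 0), (1, 0, 0, 0)]
pts = [(0.2, 0.9, 0.5, 0), (1.5, 0.1, 0.3, 0), (1.5, 1.5, 1.5, 0)]
print(nearest_interp(feats, coords, pts))  # [[5.0], [7.0], [0.0]]
```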

octree_linear_pts(data: Tensor, octree: Octree, depth: int, pts: Tensor, nempty: bool = False, bound_check: bool = False)

Linear interpolation with input points.

Refer to octree_nearest_pts() for the meaning of the arguments.
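A pure-Python sketch of trilinear interpolation in the same setting follows. It assumes features are defined at voxel centers, that pts is already scaled into [0, 2^depth), and that the weights of missing neighbor nodes are dropped and the remaining weights renormalized; these are assumptions about the behavior, and linear_interp is a hypothetical helper.

```python
import math
from itertools import product
from typing import Dict, List, Tuple

Key = Tuple[int, int, int, int]  # (x, y, z, batch)

def linear_interp(
    feats: List[List[float]],
    coords: List[Key],
    pts: List[Tuple[float, float, float, int]],
) -> List[List[float]]:
    """Hypothetical trilinear interpolation over the 8 surrounding voxel centers."""
    lookup: Dict[Key, int] = {c: i for i, c in enumerate(coords)}
    channels = len(feats[0]) if feats else 0
    out = []
    for x, y, z, b in pts:
        # Shift by 0.5 so that voxel centers sit on integer coordinates.
        f = (x - 0.5, y - 0.5, z - 0.5)
        base = tuple(math.floor(v) for v in f)
        frac = tuple(v - i for v, i in zip(f, base))
        acc = [0.0] * channels
        wsum = 0.0
        for corner in product((0, 1), repeat=3):
            key = (base[0] + corner[0], base[1] + corner[1], base[2] + corner[2], b)
            j = lookup.get(key)
            if j is None:
                continue  # missing neighbor: its weight is dropped
            w = 1.0
            for t, c in zip(frac, corner):
                w *= t if c == 1 else 1.0 - t
            wsum += w
            for ch in range(channels):
                acc[ch] += w * feats[j][ch]
        # Renormalize by the total weight of the neighbors that were found.
        out.append([a / wsum for a in acc] if wsum > 0 else acc)
    return out

feats = [[0.0], [10.0]]
coords = [(0, 0, 0, 0), (1, 0, 0, 0)]
out = linear_interp(feats, coords, [(1.0, 0.5, 0.5, 0), (1.5, 0.5, 0.5, 0)])
print(out)  # [[5.0], [10.0]]
```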

octree_max_pool(data: Tensor, octree: Octree, depth: int, nempty: bool = False, return_indices: bool = False)

Performs octree max pooling with kernel size 2 and stride 2.

Parameters:
  • data (torch.Tensor) – The input tensor.

  • octree (Octree) – The corresponding octree.

  • depth (int) – The depth of the current octree. After pooling, the corresponding depth decreases by 1.

  • nempty (bool) – If True, data contains only features of non-empty octree nodes.

  • return_indices (bool) – If True, returns the indices, which can be used in octree_max_unpool().
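For the nempty=False case, pooling with kernel size 2 and stride 2 reduces each group of eight siblings to one row. The sketch below assumes the eight children of each parent are stored contiguously, so data can be read in blocks of eight rows. It also returns, per channel, the position (0 to 7) of the maximum inside its group; that is how this sketch encodes the indices, which is not necessarily ocnn's encoding. max_pool_siblings is hypothetical.

```python
from typing import List, Tuple

def max_pool_siblings(
    data: List[List[float]],
) -> Tuple[List[List[float]], List[List[int]]]:
    """Hypothetical sketch: channel-wise max over blocks of 8 sibling rows."""
    assert len(data) % 8 == 0, "expects complete groups of 8 siblings"
    channels = len(data[0]) if data else 0
    pooled, indices = [], []
    for start in range(0, len(data), 8):
        group = data[start:start + 8]
        out_row, idx_row = [], []
        for c in range(channels):
            # position (0..7) of the largest value of channel c in this group
            best = max(range(8), key=lambda r: group[r][c])
            out_row.append(group[best][c])
            idx_row.append(best)
        pooled.append(out_row)
        indices.append(idx_row)
    return pooled, indices

data = [[float(i), float(-i)] for i in range(16)]
pooled, indices = max_pool_siblings(data)
print(pooled)   # [[7.0, 0.0], [15.0, -8.0]]
print(indices)  # [[7, 0], [7, 0]]
```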

octree_max_unpool(data: Tensor, indices: Tensor, octree: Octree, depth: int, nempty: bool = False)

Performs octree max unpooling.

Parameters:
  • data (torch.Tensor) – The input tensor.

  • indices (torch.Tensor) – The indices returned by octree_max_pool(). The depth of indices is 1 larger than that of data.

  • octree (Octree) – The corresponding octree.

  • depth (int) – The depth of the current data. After unpooling, the corresponding depth increases by 1.

  • nempty (bool) – If True, data contains only features of non-empty octree nodes.
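Unpooling reverses the sibling layout: each pooled value goes back to the child position recorded in indices, and all other children receive zeros. This sketch assumes indices store, per channel, the position (0 to 7) inside each group of eight siblings and that pooled holds one row per parent; the helper max_unpool_siblings is hypothetical.

```python
from typing import List

def max_unpool_siblings(
    pooled: List[List[float]], indices: List[List[int]]
) -> List[List[float]]:
    """Hypothetical sketch: scatter each pooled value back to its child slot."""
    channels = len(pooled[0]) if pooled else 0
    out = [[0.0] * channels for _ in range(8 * len(pooled))]
    for m, (vals, idx) in enumerate(zip(pooled, indices)):
        for c in range(channels):
            out[8 * m + idx[c]][c] = vals[c]
    return out

out = max_unpool_siblings([[5.0, 3.0]], [[2, 6]])
print(out[2], out[6], out[0])  # [5.0, 0.0] [0.0, 3.0] [0.0, 0.0]
```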

octree_global_pool(data: Tensor, octree: Octree, depth: int, nempty: bool = False)

Performs octree global average pooling.

Parameters:
  • data (torch.Tensor) – The input tensor.

  • octree (Octree) – The corresponding octree.

  • depth (int) – The depth of the current octree.

  • nempty (bool) – If True, data contains only features of non-empty octree nodes.
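Global average pooling produces one feature vector per octree in the batch. The pure-Python sketch below assumes the batch index of every row is known (for instance, extracted from the octree); global_avg_pool is a hypothetical helper, and it simply averages whatever rows it receives.

```python
from typing import List

def global_avg_pool(
    data: List[List[float]], batch_ids: List[int], batch_size: int
) -> List[List[float]]:
    """Hypothetical sketch: mean of the node features of each octree."""
    channels = len(data[0]) if data else 0
    sums = [[0.0] * channels for _ in range(batch_size)]
    counts = [0] * batch_size
    for row, b in zip(data, batch_ids):
        counts[b] += 1
        for c in range(channels):
            sums[b][c] += row[c]
    # an octree with no rows keeps an all-zero vector instead of dividing by 0
    return [[s / n for s in sums[b]] if n else sums[b] for b, n in enumerate(counts)]

print(global_avg_pool([[1.0], [3.0], [10.0]], [0, 0, 1], batch_size=2))  # [[2.0], [10.0]]
```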

octree_avg_pool(data: Tensor, octree: Octree, depth: int, kernel: str, stride: int = 2, nempty: bool = False)

Performs octree average pooling.

Parameters:
  • data (torch.Tensor) – The input tensor.

  • octree (Octree) – The corresponding octree.

  • depth (int) – The depth of the current octree.

  • kernel (str) – The kernel size, e.g. '333' or '222'.

  • stride (int) – The stride of the pooling.

  • nempty (bool) – If True, data contains only features of non-empty octree nodes.
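One way to picture the general kernel/stride case is an average over a precomputed neighbor table, where each output row lists the input rows it covers and -1 marks a missing neighbor. In this sketch missing neighbors are excluded from the denominator; that choice, the table layout, and the helper avg_pool_with_neighbors are assumptions for illustration, and building the table for a given kernel and stride is not shown.

```python
from typing import List

def avg_pool_with_neighbors(
    data: List[List[float]], neighbors: List[List[int]]
) -> List[List[float]]:
    """Hypothetical sketch: average the valid neighbors listed for each output row."""
    channels = len(data[0]) if data else 0
    out = []
    for nbrs in neighbors:
        valid = [j for j in nbrs if j >= 0]  # -1 marks a missing neighbor
        acc = [0.0] * channels
        for j in valid:
            for c in range(channels):
                acc[c] += data[j][c]
        out.append([a / len(valid) for a in acc] if valid else acc)
    return out

print(avg_pool_with_neighbors([[2.0], [4.0], [9.0]], [[0, 1, -1], [2, -1, -1]]))
# [[3.0], [9.0]]
```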

class Octree2Voxel(nempty: bool = False)

Converts the input feature to the full-voxel-based representation.

Please refer to octree2voxel() for details.

forward(data: Tensor, octree: Octree, depth: int)

class OctreeMaxPool(nempty: bool = False, return_indices: bool = False)

Performs octree max pooling.

Please refer to octree_max_pool() for details.

forward(data: Tensor, octree: Octree, depth: int)

class OctreeMaxUnpool(nempty: bool = False)

Performs octree max unpooling.

Please refer to octree_max_unpool() for details.

forward(data: Tensor, indices: Tensor, octree: Octree, depth: int)

class OctreeGlobalPool(nempty: bool = False)

Performs octree global pooling.

Please refer to octree_global_pool() for details.

forward(data: Tensor, octree: Octree, depth: int)

class OctreeAvgPool(kernel_size: List[int], stride: int, nempty: bool = False)

Performs octree average pooling.

Please refer to octree_avg_pool() for details.

forward(data: Tensor, octree: Octree, depth: int)

class OctreeConv(in_channels: int, out_channels: int, kernel_size: List[int] = [3], stride: int = 1, nempty: bool = False, method: str = 'triton', use_bias: bool = False, max_buffer: int = 200000000)

Performs octree convolution.

Parameters:
  • in_channels (int) – Number of input channels.

  • out_channels (int) – Number of output channels.

  • kernel_size (List(int)) – The kernel shape, choose from [3], [2], [3,3,3], [3,1,1], [1,3,1], [1,1,3], [2,2,2], [3,3,1], [1,3,3], and [3,1,3].

  • stride (int) – The stride of the convolution (1 or 2).

  • nempty (bool) – If True, only performs the convolution on non-empty octree nodes.

  • method (str) – Which implementation to use: 'explicit_gemm', 'block_gemm', or 'triton'. 'explicit_gemm' builds the full column matrix via octree2col()/col2octree() and then applies a GEMM, which can use a large amount of memory. 'block_gemm' computes in smaller blocks to reduce peak memory at some runtime cost. 'triton' uses the implicit GEMM kernel implemented in Triton and requires kernel_size=[3] or [3,3,3], stride=1, CUDA, and PyTorch >= 2.8.0.

  • use_bias (bool) – If True, add a bias term to the convolution.

  • max_buffer (int) – The maximum number of elements in the buffer, used when method is 'block_gemm'.

Note

Each non-empty octree node has exactly 8 child nodes, some of which are non-empty and some empty. If nempty is True, the convolution is performed on non-empty octree nodes only, which is exactly the same as SparseConvNet and MinkowskiNet; if nempty is False, the convolution is performed on all octree nodes, which is essential for shape reconstruction tasks and can also be used in classification and segmentation (with slightly better performance and larger memory cost).

check_method(method: str)

reset_parameters()

is_conv_layer()

Returns True to indicate this is a convolution layer.

explicit_gemm(data: Tensor, octree: Octree, depth: int)

Performs the convolution by explicitly constructing the column data.

block_gemm(data: Tensor, octree: Octree, depth: int)

Performs the convolution block by block, which reduces the runtime memory required.
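The difference between explicit_gemm and block_gemm can be sketched in pure Python. Both compute the same result: gather each node's neighbor features into a column of K x Cin values and multiply it by a (K x Cin) x Cout weight matrix. The explicit version materializes all N columns at once, while the blocked version gathers only as many rows as fit into max_buffer elements at a time. The neighbor table (with -1 for missing neighbors), the weight layout, and all helper names are assumptions, not ocnn internals.

```python
from typing import List

Matrix = List[List[float]]

def gather_cols(data: Matrix, neighbors: List[List[int]]) -> Matrix:
    """Concatenate the K neighbor features of each row (-1 -> zeros)."""
    cin = len(data[0])
    cols = []
    for nbrs in neighbors:
        col: List[float] = []
        for j in nbrs:
            col.extend(data[j] if j >= 0 else [0.0] * cin)
        cols.append(col)
    return cols

def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def explicit_gemm(data: Matrix, neighbors: List[List[int]], weight: Matrix) -> Matrix:
    # Builds the full N x (K * Cin) column matrix before multiplying.
    return matmul(gather_cols(data, neighbors), weight)

def block_gemm(data: Matrix, neighbors: List[List[int]], weight: Matrix,
               max_buffer: int) -> Matrix:
    # Gathers columns for at most `rows` nodes at a time; `rows` is chosen so a
    # block holds no more than max_buffer elements, except that every block
    # contains at least one row.
    k, cin = len(neighbors[0]), len(data[0])
    rows = max(1, max_buffer // (k * cin))
    out: Matrix = []
    for start in range(0, len(neighbors), rows):
        block = gather_cols(data, neighbors[start:start + rows])
        out.extend(matmul(block, weight))
    return out

data = [[1.0], [2.0], [3.0]]           # N = 3 nodes, Cin = 1
neighbors = [[0, 1], [1, 2], [2, -1]]  # K = 2 neighbors per node
weight = [[1.0], [10.0]]               # (K * Cin) x Cout with Cout = 1
print(explicit_gemm(data, neighbors, weight))             # [[21.0], [32.0], [3.0]]
print(block_gemm(data, neighbors, weight, max_buffer=2))  # [[21.0], [32.0], [3.0]]
```

Because only one block of columns exists at a time in the blocked version, its peak column memory scales with max_buffer rather than with N, at the cost of more, smaller multiplications.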

implicit_gemm(data: Tensor, octree: Octree, depth: int)

Performs the convolution via the implicit GEMM kernel implemented in Triton.

forward(data: Tensor, octree: Octree, depth: int)

Defines the octree convolution.

Parameters:
  • data (torch.Tensor) – The input data.

  • octree (Octree) – The corresponding octree.

  • depth (int) – The depth of the current octree.

class OctreeDeconv(in_channels: int, out_channels: int, kernel_size: List[int] = [3], stride: int = 1, nempty: bool = False, method: str = 'triton', use_bias: bool = False, max_buffer: int = 200000000)

Performs octree deconvolution.

Please refer to OctreeConv for the meaning of the arguments.

is_conv_layer()

Returns True to indicate this is a convolution layer.

forward(data: Tensor, octree: Octree, depth: int)

Defines the octree deconvolution.

Please refer to OctreeConv.forward() for the meaning of the arguments.

class OctreeGroupConv(in_channels: int, out_channels: int, kernel_size: List[int] = [3], stride: int = 1, nempty: bool = False, use_bias: bool = False, group: int = 1)

Performs octree-based group convolution.

Parameters:
  • in_channels (int) – Number of input channels.

  • out_channels (int) – Number of output channels.

  • kernel_size (List(int)) – The kernel shape, choose from [3], [2], [3,3,3], [3,1,1], [1,3,1], [1,1,3], [2,2,2], [3,3,1], [1,3,3], and [3,1,3].

  • stride (int) – The stride of the convolution (1 or 2).

  • nempty (bool) – If True, only performs the convolution on non-empty octree nodes.

  • use_bias (bool) – If True, add a bias term to the convolution.

  • group (int) – The number of groups.

Note

This module performs octree-based group convolution with a for-loop over the groups, so its performance is not optimal. Use it only when the number of groups is small; otherwise it may be slow.
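The for-loop structure can be illustrated with a 1x1 kernel, so that the convolution reduces to a per-node linear map and only the grouping logic remains. This is a simplified sketch, not OctreeGroupConv itself: the input channels are split into group equal slices, each slice is multiplied by its own (Cin/group) x (Cout/group) matrix, and the results are concatenated. The helper name group_linear and the weight layout are hypothetical.

```python
from typing import List

Matrix = List[List[float]]

def group_linear(data: Matrix, weights: List[Matrix]) -> Matrix:
    """Hypothetical sketch: grouped per-node linear map with one loop per group."""
    group = len(weights)
    cin = len(data[0])
    assert cin % group == 0, "in_channels must be divisible by the group number"
    step = cin // group
    out: Matrix = [[] for _ in data]
    for g, w in enumerate(weights):  # one iteration per group
        for n, row in enumerate(data):
            x = row[g * step:(g + 1) * step]
            out[n].extend(sum(xi * wi for xi, wi in zip(x, col)) for col in zip(*w))
    return out

# Two groups, each mapping 1 input channel to 1 output channel.
print(group_linear([[1.0, 2.0]], [[[3.0]], [[5.0]]]))  # [[3.0, 10.0]]
```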

forward(data: Tensor, octree: Octree, depth: int)

Defines the octree-based group convolution.

Parameters:
  • data (torch.Tensor) – The input data.

  • octree (Octree) – The corresponding octree.

  • depth (int) – The depth of the current octree.

class OctreeDWConv(in_channels: int, kernel_size: List[int] = [3], stride: int = 1, nempty: bool = False, use_bias: bool = False, max_buffer: int = 200000000)

Performs octree-based depth-wise convolution.

Please refer to ocnn.nn.OctreeConv for the meaning of the arguments.

Note

This implementation relies on torch.einsum() and is relatively slow; further optimization is needed to speed it up.
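In a depth-wise convolution each channel has its own K-tap filter and channels never mix: out[n][c] is the sum over k of cols[n][k][c] * weight[k][c], i.e. the contraction written as einsum('nkc,kc->nc') on gathered neighbor features. The pure-Python sketch below spells that out; the gathered-column layout (N x K x C), the weight layout (K x C), and the helper name depthwise_conv are assumptions.

```python
from typing import List

def depthwise_conv(cols: List[List[List[float]]],
                   weight: List[List[float]]) -> List[List[float]]:
    """Hypothetical sketch: per-channel K-tap filtering of N x K x C columns."""
    k_size = len(weight)
    channels = len(weight[0])
    return [
        [sum(col[k][c] * weight[k][c] for k in range(k_size)) for c in range(channels)]
        for col in cols
    ]

# One node, K = 2 neighbors, C = 2 channels.
print(depthwise_conv([[[1.0, 2.0], [3.0, 4.0]]], [[1.0, 0.5], [2.0, 1.0]]))  # [[7.0, 5.0]]
```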

reset_parameters()

forward(data: Tensor, octree: Octree, depth: int)

class OctreeInterp(method: str = 'linear', nempty: bool = False, bound_check: bool = False, rescale_pts: bool = True)

Interpolates octree features at the input points.

Refer to octree_nearest_pts() for a description of the arguments.

forward(data: Tensor, octree: Octree, depth: int, pts: Tensor, bbmin: Tensor | float = -1, bbmax: Tensor | float = 1)
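When rescale_pts is True, the points are presumably mapped from the bounding box [bbmin, bbmax] into the [0, 2^depth) range expected by octree_nearest_pts(). A minimal sketch of such an affine rescaling, assuming scalar bounds and (x, y, z, batch) tuples, is shown below; the helper name rescale_points is hypothetical and the exact formula used by OctreeInterp is an assumption.

```python
from typing import List, Tuple

Point = Tuple[float, float, float, int]  # (x, y, z, batch)

def rescale_points(pts: List[Point], depth: int,
                   bbmin: float = -1.0, bbmax: float = 1.0) -> List[Point]:
    """Hypothetical sketch: affinely map xyz from [bbmin, bbmax] to [0, 2**depth]."""
    scale = 2 ** depth / (bbmax - bbmin)
    return [((x - bbmin) * scale, (y - bbmin) * scale, (z - bbmin) * scale, b)
            for x, y, z, b in pts]

print(rescale_points([(-1.0, 0.0, 0.5, 0)], depth=3))  # [(0.0, 4.0, 6.0, 0)]
```

With this mapping a coordinate equal to bbmax lands exactly on 2^depth, the upper end of the range.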
class OctreeUpsample(method: str = 'linear', nempty: bool = False)

Upsamples the octree node features from depth to target_depth.

Refer to octree_nearest_pts() for details.

forward(data: Tensor, octree: Octree, depth: int, target_depth: int | None = None)
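One plausible reading of nearest upsampling, consistent with the reference to octree_nearest_pts(), is to sample the coarse features at the centers of the target-depth nodes. A node (x, y, z) at target_depth has its center at (x + 0.5, y + 0.5, z + 0.5); multiplying by 2^(depth - target_depth) expresses it in depth coordinates, and flooring picks the enclosing coarse node. The sketch below is a hypothetical illustration of that idea, not OctreeUpsample's code.

```python
import math
from typing import Dict, List, Tuple

Key = Tuple[int, int, int, int]  # (x, y, z, batch)

def upsample_nearest(
    feats: List[List[float]],
    coords: List[Key],
    target_coords: List[Key],
    depth: int,
    target_depth: int,
) -> List[List[float]]:
    """Hypothetical sketch: give each target node the feature of its enclosing coarse node."""
    lookup: Dict[Key, int] = {c: i for i, c in enumerate(coords)}
    scale = 2.0 ** (depth - target_depth)
    channels = len(feats[0]) if feats else 0
    out = []
    for x, y, z, b in target_coords:
        # Center of the target node, expressed in depth coordinates, then floored.
        key = (math.floor((x + 0.5) * scale), math.floor((y + 0.5) * scale),
               math.floor((z + 0.5) * scale), b)
        j = lookup.get(key)
        out.append(list(feats[j]) if j is not None else [0.0] * channels)
    return out

coarse = [(0, 0, 0, 0), (1, 0, 0, 0)]                            # nodes at depth 1
fine = [(0, 0, 0, 0), (1, 1, 1, 0), (2, 0, 0, 0), (3, 1, 0, 0)]  # nodes at depth 2
out = upsample_nearest([[1.0], [2.0]], coarse, fine, depth=1, target_depth=2)
print(out)  # [[1.0], [1.0], [2.0], [2.0]]
```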
class OctreeInstanceNorm(in_channels: int, nempty: bool = False)

An instance normalization layer for the octree.

OctreeBatchNorm

alias of torch.nn.BatchNorm1d

class OctreeGroupNorm(in_channels: int, group: int, nempty: bool = False, min_group_channels: int = 4)

A group normalization layer for the octree.

reset_parameters()

forward(data: Tensor, octree: Octree, depth: int)
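The statistics behind group normalization can be sketched in pure Python: for every octree in the batch and every group of in_channels / group consecutive channels, the features are normalized by the mean and (biased) variance over all nodes of that octree and all channels of that group. With one channel per group this reduces to the per-octree, per-channel statistics of instance normalization. The sketch assumes batch ids are known, omits the learnable affine parameters and min_group_channels, and uses a hypothetical helper name; it is not OctreeGroupNorm's code.

```python
import math
from typing import Dict, List

def group_norm(data: List[List[float]], batch_ids: List[int], group: int,
               eps: float = 1e-5) -> List[List[float]]:
    """Hypothetical sketch: normalize per (octree, channel group) without affine terms."""
    channels = len(data[0])
    assert channels % group == 0, "in_channels must be divisible by group"
    per_group = channels // group
    rows_of: Dict[int, List[int]] = {}
    for n, b in enumerate(batch_ids):
        rows_of.setdefault(b, []).append(n)
    out = [[0.0] * channels for _ in data]
    for rows in rows_of.values():
        for g in range(group):
            chans = range(g * per_group, (g + 1) * per_group)
            vals = [data[n][c] for n in rows for c in chans]
            mean = sum(vals) / len(vals)
            var = sum((v - mean) ** 2 for v in vals) / len(vals)
            inv_std = 1.0 / math.sqrt(var + eps)
            for n in rows:
                for c in chans:
                    out[n][c] = (data[n][c] - mean) * inv_std
    return out

x = [[1.0, 3.0, 10.0, 10.0], [1.0, 3.0, 10.0, 10.0]]
y = group_norm(x, batch_ids=[0, 0], group=2)
# Group 0 has mean 2 and variance 1, so y[0][:2] is about [-1, 1];
# group 1 is constant, so y[0][2:] is [0.0, 0.0].
```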
class OctreeNorm(in_channels: int, norm_type: str = 'batch_norm', group: int = 32, min_group_channels: int = 4)

A normalization layer for the octree. It encapsulates octree-based batch, group and instance normalization.

forward(x: Tensor, octree: Octree | None = None, depth: int | None = None)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class OctreeDropPath(drop_prob: float = 0.0, nempty: bool = False, scale_by_keep: bool = True)

Drop paths (Stochastic Depth) per octree when applied in the main path of residual blocks.

Parameters:
  • drop_prob (float) – The probability of dropping a path.

  • nempty (bool) – Indicates whether the input data only contains features of the non-empty octree nodes.

  • scale_by_keep (bool) – Whether to scale the kept features proportionally.

forward(data: Tensor, octree: Octree, depth: int, batch_id: Tensor | None = None)

Defines the drop path forward function.

Parameters:
  • data (torch.Tensor) – The input features of shape (N, C), where N is the number of octree nodes and C is the number of channels.

  • octree (Octree) – The input octree.

  • depth (int) – The depth of the octree layer.

  • batch_id (torch.Tensor, optional) – The batch indices of the octree nodes. If not provided, it will be extracted from the octree.
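Per-octree drop path can be sketched as follows: during training, one keep/drop decision is drawn for each octree in the batch and applied to all of its nodes, and when scale_by_keep is True the kept features are divided by 1 - drop_prob so their expected value is unchanged. The training flag, the rng argument, and the helper name octree_drop_path are hypothetical, and batch_id is assumed to hold one batch index per row of data.

```python
import random
from typing import List, Optional

def octree_drop_path(
    data: List[List[float]],
    batch_id: List[int],
    drop_prob: float,
    training: bool = True,
    scale_by_keep: bool = True,
    rng: Optional[random.Random] = None,
) -> List[List[float]]:
    """Hypothetical sketch: zero out whole octrees with probability drop_prob."""
    if drop_prob == 0.0 or not training or not data:
        return [list(row) for row in data]
    rng = rng or random.Random()
    keep_prob = 1.0 - drop_prob
    num_octrees = max(batch_id) + 1
    # One Bernoulli(keep_prob) draw per octree, shared by all of its nodes.
    mask = [1.0 if rng.random() < keep_prob else 0.0 for _ in range(num_octrees)]
    if scale_by_keep and keep_prob > 0.0:
        mask = [m / keep_prob for m in mask]
    return [[v * mask[b] for v in row] for row, b in zip(data, batch_id)]

out = octree_drop_path([[1.0], [1.0], [1.0]], [0, 0, 1], drop_prob=0.5,
                       rng=random.Random(0))
# Rows 0 and 1 belong to octree 0, so they are either both 0.0 or both 2.0.
```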

class DropPath(drop_prob: float = 0.0, scale_by_keep: bool = True, **kwargs)

Drop paths (Stochastic Depth) per token when applied in the main path of residual blocks, following the logic of timm.models.layers.DropPath().

Parameters:
  • drop_prob (float) – The probability of dropping a path.

  • scale_by_keep (bool) – Whether to scale the kept features proportionally.

forward(data: Tensor, **kwargs)

Defines the drop path forward function.

Parameters:
  • data (torch.Tensor) – The input features of shape (N, C), where N is the number of tokens and C is the number of channels.

search_value(value: Tensor, key: Tensor, query: Tensor)

Searches values according to sorted shuffled keys.

Parameters:
  • value (torch.Tensor) – The input tensor with shape (N, C).

  • key (torch.Tensor) – The key tensor of shape (N,) corresponding to value; it contains the sorted shuffled keys of an octree.

  • query (torch.Tensor) – The query tensor, which also contains shuffled keys.
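Because the keys are sorted, each query can be answered with a binary search. The sketch below uses Python's bisect module; returning zeros for queries that do not occur in key is an assumption about how unmatched queries are handled, and search_value_sketch is a hypothetical name.

```python
from bisect import bisect_left
from typing import List

def search_value_sketch(value: List[List[float]], key: List[int],
                        query: List[int]) -> List[List[float]]:
    """Hypothetical sketch: look up rows of `value` by sorted `key`."""
    channels = len(value[0]) if value else 0
    out = []
    for q in query:
        i = bisect_left(key, q)  # first position with key[i] >= q
        if i < len(key) and key[i] == q:
            out.append(list(value[i]))
        else:
            out.append([0.0] * channels)  # q does not occur in key
    return out

print(search_value_sketch([[1.0], [2.0], [3.0]], [3, 8, 20], [8, 5, 20, 42]))
# [[2.0], [0.0], [3.0], [0.0]]
```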

octree_align(value: Tensor, octree: Octree, octree_query: Octree, depth: int, nempty: bool = False)

Wraps search_value() to take octrees as input for convenience.

class OctreeConvTriton(in_channels: int, out_channels: int, kernel_size: List[int] = [3], stride: int = 1, nempty: bool = False, method: str = 'triton', use_bias: bool = False, max_buffer: int = 200000000)

Performs octree convolution.

Parameters:
  • in_channels (int) – Number of input channels.

  • out_channels (int) – Number of output channels.

  • kernel_size (List(int)) – The kernel shape; the Triton implementation currently supports only [3] and [3,3,3].

  • stride (int) – The stride of the convolution; only 1 is currently supported.

  • nempty (bool) – If True, only performs the convolution on non-empty octree nodes; otherwise, performs the convolution on all octree nodes.

  • use_bias (bool) – If True, add a bias term to the convolution.

Note

Each non-empty octree node has exactly 8 child nodes, some of which are non-empty and some empty. If nempty is True, the convolution is performed on non-empty octree nodes only, which is exactly the same as SparseConvNet and MinkowskiNet; if nempty is False, the convolution is performed on all octree nodes, which is essential for shape reconstruction tasks and can also be used in classification and segmentation (with slightly better performance and larger memory cost).

reset_parameters()

forward(data: Tensor, octree: Octree, depth: int)

Defines the octree convolution.

Parameters:
  • data (torch.Tensor) – The input data.

  • octree (Octree) – The corresponding octree.

  • depth (int) – The depth of the current octree.

OctreeConvT

alias of OctreeConvTriton

convert_conv_triton(module: Module) → Module

Converts OctreeConv modules to OctreeConvTriton modules in a network.

Parameters:
  • module (torch.nn.Module) – The input module.