ocnn.nn
- octree2voxel: Converts the input feature to the full-voxel-based representation.
- octree2col: Gathers the neighboring features for convolutions.
- col2octree: Scatters the convolution features to an octree.
- octree_pad: Pads data with val so that the number of elements of data equals the number of octree nodes.
- octree_depad: Reverse operation of octree_pad().
- octree_nearest_pts: The nearest-neighbor interpolation with input points.
- octree_linear_pts: Linear interpolation with input points.
- octree_max_pool: Performs octree max pooling with kernel size 2 and stride 2.
- octree_max_unpool: Performs octree max unpooling.
- octree_global_pool: Performs octree global average pooling.
- octree_avg_pool: Performs octree average pooling.
- Octree2Voxel: Converts the input feature to the full-voxel-based representation.
- OctreeMaxPool: Performs octree max pooling.
- OctreeMaxUnpool: Performs octree max unpooling.
- OctreeGlobalPool: Performs octree global pooling.
- OctreeAvgPool: Performs octree average pooling.
- OctreeConv: Performs octree convolution.
- OctreeDeconv: Performs octree deconvolution.
- OctreeGroupConv: Performs octree-based group convolution.
- OctreeDWConv: Performs octree-based depth-wise convolution.
- OctreeInterp: Interpolates the points with an octree feature.
- OctreeUpsample: Upsamples the octree node features from depth to target_depth.
- OctreeInstanceNorm: An instance normalization layer for the octree.
- OctreeBatchNorm: alias of BatchNorm1d
- OctreeGroupNorm: A group normalization layer for the octree.
- OctreeNorm: A normalization layer for the octree.
- OctreeDropPath: Drops paths (Stochastic Depth) per octree when applied in the main path of residual blocks.
- DropPath: Drops paths (Stochastic Depth) per token when applied in the main path of residual blocks, following the logic of timm.models.layers.DropPath().
- search_value: Searches values according to sorted shuffled keys.
- octree_align: Wraps search_value() to take octrees as input for convenience.
- OctreeConvTriton: Performs octree convolution.
- OctreeConvT: alias of OctreeConvTriton
- convert_conv_triton: Converts OctreeConv modules to OctreeConvTriton modules in a network.
- octree2voxel(data: Tensor, octree: Octree, depth: int, nempty: bool = False)
Converts the input feature to the full-voxel-based representation.
- Parameters:
data (torch.Tensor) – The input feature.
octree (Octree) – The corresponding octree.
depth (int) – The depth of the current octree.
nempty (bool) – If True, data only contains the features of non-empty octree nodes.
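The following is a minimal pure-Python sketch of what the conversion does. It is not the library's implementation. It assumes the (x, y, z, batch) voxel coordinates of each node at depth are already known (the library derives them from the octree). The nested-list layout grid[b][x][y][z] is also an assumption of this sketch.

```python
from typing import List, Sequence, Tuple

Coord = Tuple[int, int, int, int]  # (x, y, z, batch) of a node at `depth`


def scatter_to_voxels(data: Sequence[Sequence[float]],
                      coords: Sequence[Coord],
                      depth: int,
                      batch_size: int) -> List:
    """Places each node feature into a dense grid[b][x][y][z] of zero features."""
    num = 1 << depth  # 2^depth voxels along each axis
    channels = len(data[0])
    grid = [[[[[0.0] * channels for _ in range(num)]
              for _ in range(num)]
             for _ in range(num)]
            for _ in range(batch_size)]
    for feature, (x, y, z, b) in zip(data, coords):
        grid[b][x][y][z] = list(feature)  # voxels without a node keep zeros
    return grid


# Two nodes at depth 1 (a 2 x 2 x 2 grid) in a batch of one octree.
grid = scatter_to_voxels([[1.0], [2.0]], [(0, 0, 0, 0), (1, 0, 1, 0)],
                         depth=1, batch_size=1)
```

The grid has (2^depth)^3 cells per octree regardless of how many nodes exist, which is why this representation grows quickly with depth.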
- octree2col(data: Tensor, octree: Octree, depth: int, kernel_size: str = '333', stride: int = 1, nempty: bool = False)
Gathers the neighboring features for convolutions.
- Parameters:
data (torch.Tensor) – The input data.
octree (Octree) – The corresponding octree.
depth (int) – The depth of the current octree.
kernel_size (str) – The kernel shape, chosen from 333, 311, 131, 113, 222, 331, 133, and 313.
stride (int) – The stride of neighborhoods (1 or 2). If the stride is 2, it always returns the neighborhood of the first sibling in each group of eight siblings, and the number of elements of the output tensor is octree.nnum[depth] / 8.
nempty (bool) – If True, only returns the neighborhoods of the non-empty octree nodes.
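The gather can be pictured with a neighbor table. The sketch below is an illustration, not the library code. It uses neigh, a hypothetical (N, K) table of node indices in which -1 marks a missing neighbor. The output follows an assumed (N, K, C) layout, where K is the kernel volume (27 for 333, 9 for 313, and so on).

```python
from typing import List, Sequence


def gather_columns(data: Sequence[Sequence[float]],
                   neigh: Sequence[Sequence[int]]) -> List[List[List[float]]]:
    """Returns col[i][k] = data[neigh[i][k]], or zeros where neigh[i][k] == -1."""
    channels = len(data[0])
    col = []
    for row in neigh:
        col.append([list(data[j]) if j >= 0 else [0.0] * channels for j in row])
    return col


def kernel_volume(kernel_size: str) -> int:
    """Number of neighbors K for a kernel string such as '333' or '313'."""
    volume = 1
    for digit in kernel_size:
        volume *= int(digit)
    return volume
```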
- col2octree(data: Tensor, octree: Octree, depth: int, kernel_size: str = '333', stride: int = 1, nempty: bool = False)
Scatters the convolution features to an octree.
Please refer to octree2col() for the usage of function parameters.
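One way to read "scatters" is as the adjoint of the gather: every column entry is accumulated back into the node it came from. The sketch below assumes that reading and is not the library code. It uses a hypothetical neighbor table neigh, with -1 for a missing neighbor.

```python
from typing import List, Sequence


def scatter_columns(col: Sequence[Sequence[Sequence[float]]],
                    neigh: Sequence[Sequence[int]],
                    num_nodes: int) -> List[List[float]]:
    """Adds col[i][k] back into node neigh[i][k]; entries with index -1 are skipped."""
    channels = len(col[0][0])
    out = [[0.0] * channels for _ in range(num_nodes)]
    for values, indices in zip(col, neigh):
        for feature, j in zip(values, indices):
            if j < 0:
                continue  # padded neighbor: nothing to scatter
            for c in range(channels):
                out[j][c] += feature[c]
    return out
```

Summing, rather than overwriting, matters because a node usually appears in the neighborhoods of several other nodes.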
- octree_pad(data: Tensor, octree: Octree, depth: int, val: float = 0.0)
Pads data with val so that the number of elements of data equals the number of octree nodes.
- Parameters:
data (torch.Tensor) – The input tensor, whose number of elements equals the number of non-empty octree nodes.
octree (Octree) – The corresponding octree.
depth (int) – The depth of the current octree.
val (float) – The padding value. (Default: 0.0)
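The following is a minimal sketch of the padding, not the library code. It assumes a hypothetical boolean mask marking which nodes at depth are non-empty, and that the rows of data follow the node order.

```python
from typing import List, Sequence


def pad_features(data: Sequence[Sequence[float]],
                 nempty_mask: Sequence[bool],
                 val: float = 0.0) -> List[List[float]]:
    """Expands non-empty-node features to all nodes, filling empty nodes with val."""
    channels = len(data[0])
    rows = iter(data)
    out = []
    for is_nempty in nempty_mask:
        out.append(list(next(rows)) if is_nempty else [val] * channels)
    return out
```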
- octree_depad(data: Tensor, octree: Octree, depth: int)
Reverse operation of octree_pad().
Please refer to octree_pad() for the meaning of the arguments.
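Under the same assumption of a hypothetical non-empty mask in node order, depadding simply keeps the rows of non-empty nodes. This is a sketch, not the library code.

```python
from typing import List, Sequence


def depad_features(data: Sequence[Sequence[float]],
                   nempty_mask: Sequence[bool]) -> List[List[float]]:
    """Keeps only the rows of non-empty nodes, dropping the padded rows."""
    return [list(row) for row, keep in zip(data, nempty_mask) if keep]
```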
- octree_nearest_pts(data: Tensor, octree: Octree, depth: int, pts: Tensor, nempty: bool = False, bound_check: bool = False)
The nearest-neighbor interpolation with input points.
- Parameters:
data (torch.Tensor) – The input data.
octree (Octree) – The octree to interpolate.
depth (int) – The depth of the data.
pts (torch.Tensor) – The coordinates of the points with shape (N, 4), i.e. N x (x, y, z, batch).
nempty (bool) – If True, data only contains features of non-empty octree nodes.
bound_check (bool) – If True, check whether each point is in [0, 2^depth).
Note
The pts MUST be scaled into [0, 2^depth).
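The sketch below illustrates nearest-neighbor lookup for points already scaled into [0, 2^depth). It represents the octree features as a hypothetical dictionary from integer (x, y, z, batch) voxel coordinates to features. The nearest node is taken to be the voxel containing the point, i.e. the floor of its coordinates. Points that fall outside the range or have no matching node get zeros here; that is an assumption of this sketch, not documented behavior.

```python
import math
from typing import Dict, List, Sequence, Tuple

Key = Tuple[int, int, int, int]  # integer voxel coordinate (x, y, z, batch) at `depth`


def nearest_lookup(features: Dict[Key, List[float]],
                   pts: Sequence[Tuple[float, float, float, int]],
                   depth: int,
                   channels: int) -> List[List[float]]:
    """Returns the feature of the voxel containing each point, or zeros if there is none."""
    size = 1 << depth
    out = []
    for x, y, z, b in pts:
        if not all(0.0 <= v < size for v in (x, y, z)):
            out.append([0.0] * channels)  # outside [0, 2^depth)
            continue
        key = (math.floor(x), math.floor(y), math.floor(z), int(b))
        out.append(list(features.get(key, [0.0] * channels)))
    return out
```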
- octree_linear_pts(data: Tensor, octree: Octree, depth: int, pts: Tensor, nempty: bool = False, bound_check: bool = False)
Linear interpolation with input points.
Refer to octree_nearest_pts() for the meaning of the arguments.
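The following is a sketch of trilinear interpolation at a single point. It makes two assumptions, neither of which is confirmed by the text above: features are treated as samples at voxel centers, and missing neighbors are skipped with the remaining weights renormalized. Features are held in a hypothetical coordinate-to-feature dictionary.

```python
import math
from itertools import product
from typing import Dict, List, Tuple

Key = Tuple[int, int, int, int]  # integer voxel coordinate (x, y, z, batch)


def trilinear_lookup(features: Dict[Key, List[float]],
                     pt: Tuple[float, float, float, int],
                     channels: int) -> List[float]:
    """Trilinearly blends the features of the 8 voxels around a point."""
    x, y, z, b = pt
    # Features are treated as samples at voxel centers (i + 0.5).
    fx, fy, fz = x - 0.5, y - 0.5, z - 0.5
    ix, iy, iz = math.floor(fx), math.floor(fy), math.floor(fz)
    tx, ty, tz = fx - ix, fy - iy, fz - iz
    acc = [0.0] * channels
    weight_sum = 0.0
    for dx, dy, dz in product((0, 1), repeat=3):
        key = (ix + dx, iy + dy, iz + dz, b)
        if key not in features:
            continue  # missing neighbor: skip it and renormalize below
        w = ((tx if dx else 1.0 - tx) *
             (ty if dy else 1.0 - ty) *
             (tz if dz else 1.0 - tz))
        weight_sum += w
        for c in range(channels):
            acc[c] += w * features[key][c]
    if weight_sum == 0.0:
        return acc  # no usable neighbor: zeros
    return [a / weight_sum for a in acc]
```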
- octree_max_pool(data: Tensor, octree: Octree, depth: int, nempty: bool = False, return_indices: bool = False)
Performs octree max pooling with kernel size 2 and stride 2.
- Parameters:
data (torch.Tensor) – The input tensor.
octree (Octree) – The corresponding octree.
depth (int) – The depth of the current octree. After pooling, the corresponding depth decreases by 1.
nempty (bool) – If True, data contains only features of non-empty octree nodes.
return_indices (bool) – If True, returns the indices, which can be used in octree_max_unpool().
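With kernel size 2 and stride 2, each output node summarizes its eight children. The sketch below assumes those eight siblings are stored as consecutive rows. It records the within-group position (0 to 7) of each maximum as the index; the library's index format may differ.

```python
from typing import List, Sequence, Tuple


def max_pool_siblings(data: Sequence[Sequence[float]]
                      ) -> Tuple[List[List[float]], List[List[int]]]:
    """Max-pools each run of 8 sibling rows, per channel, and records argmax positions."""
    assert len(data) % 8 == 0, "expects complete groups of 8 siblings"
    channels = len(data[0])
    pooled, indices = [], []
    for start in range(0, len(data), 8):
        group = data[start:start + 8]
        values, positions = [], []
        for c in range(channels):
            k = max(range(8), key=lambda i: group[i][c])  # first maximum wins ties
            values.append(group[k][c])
            positions.append(k)
        pooled.append(values)
        indices.append(positions)
    return pooled, indices
```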
- octree_max_unpool(data: Tensor, indices: Tensor, octree: Octree, depth: int, nempty: bool = False)
Performs octree max unpooling.
- Parameters:
data (torch.Tensor) – The input tensor.
indices (torch.Tensor) – The indices returned by octree_max_pool(). The depth of indices is larger by 1 than that of data.
octree (Octree) – The corresponding octree.
depth (int) – The depth of the current data. After unpooling, the corresponding depth increases by 1.
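Unpooling inverts the routing of max pooling. The sketch below assumes the indices are within-group positions (0 to 7) and that the eight children of each parent are written as consecutive rows. Both are assumptions of this sketch.

```python
from typing import List, Sequence


def max_unpool_siblings(pooled: Sequence[Sequence[float]],
                        indices: Sequence[Sequence[int]]) -> List[List[float]]:
    """Writes each pooled value back to its recorded child position; other children get 0."""
    out = []
    for values, positions in zip(pooled, indices):
        group = [[0.0] * len(values) for _ in range(8)]
        for c, (v, k) in enumerate(zip(values, positions)):
            group[k][c] = v
        out.extend(group)
    return out
```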
- octree_global_pool(data: Tensor, octree: Octree, depth: int, nempty: bool = False)
Performs octree global average pooling.
- Parameters:
data (torch.Tensor) – The input tensor.
octree (Octree) – The corresponding octree.
depth (int) – The depth of the current octree.
nempty (bool) – If True, data contains only features of non-empty octree nodes.
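Global average pooling yields one feature vector per octree in the batch. The sketch below assumes a hypothetical per-node batch_id list telling which octree each row belongs to.

```python
from typing import List, Sequence


def global_avg_pool(data: Sequence[Sequence[float]],
                    batch_id: Sequence[int],
                    batch_size: int) -> List[List[float]]:
    """Averages node features separately for every octree in the batch."""
    channels = len(data[0])
    sums = [[0.0] * channels for _ in range(batch_size)]
    counts = [0] * batch_size
    for row, b in zip(data, batch_id):
        counts[b] += 1
        for c in range(channels):
            sums[b][c] += row[c]
    # An octree with no nodes at this depth keeps a zero feature.
    return [[s / max(n, 1) for s in total] for total, n in zip(sums, counts)]
```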
- octree_avg_pool(data: Tensor, octree: Octree, depth: int, kernel: str, stride: int = 2, nempty: bool = False)
Performs octree average pooling.
- Parameters:
- class Octree2Voxel(nempty: bool = False)
Converts the input feature to the full-voxel-based representation.
Please refer to octree2voxel() for details.
- class OctreeMaxPool(nempty: bool = False, return_indices: bool = False)
Performs octree max pooling.
Please refer to octree_max_pool() for details.
- class OctreeMaxUnpool(nempty: bool = False)
Performs octree max unpooling.
Please refer to octree_max_unpool() for details.
- class OctreeGlobalPool(nempty: bool = False)
Performs octree global pooling.
Please refer to octree_global_pool() for details.
- class OctreeAvgPool(kernel_size: List[int], stride: int, nempty: bool = False)
Performs octree average pooling.
Please refer to octree_avg_pool() for details.
- class OctreeConv(in_channels: int, out_channels: int, kernel_size: List[int] = [3], stride: int = 1, nempty: bool = False, method: str = 'triton', use_bias: bool = False, max_buffer: int = 200000000)
Performs octree convolution.
- Parameters:
in_channels (int) – Number of input channels.
out_channels (int) – Number of output channels.
kernel_size (List(int)) – The kernel shape, chosen from [3], [2], [3,3,3], [3,1,1], [1,3,1], [1,1,3], [2,2,2], [3,3,1], [1,3,3], and [3,1,3].
stride (int) – The stride of the convolution (1 or 2).
nempty (bool) – If True, only performs the convolution on non-empty octree nodes.
method (str) – Which implementation to use. Options are 'explicit_gemm', 'block_gemm', and 'triton'. 'explicit_gemm' builds the full column matrix via octree2col/col2octree and then uses GEMM; this can use a large amount of memory. 'block_gemm' computes in smaller blocks to reduce peak memory at some runtime cost. 'triton' uses the implicit kernel and requires kernel_size=[3,3,3], stride=1, CUDA, and PyTorch >= 2.8.0.
use_bias (bool) – If True, add a bias term to the convolution.
max_buffer (int) – The maximum number of elements in the buffer, used when method is 'block_gemm'.
Note
Each non-empty octree node has exactly 8 child nodes, some of which are non-empty and some empty. If nempty is True, the convolution is performed on non-empty octree nodes only, which is exactly the same as SparseConvNet and MinkowskiNet. If nempty is False, the convolution is performed on all octree nodes, which is essential for shape reconstruction tasks and can also be used in classification and segmentation (with slightly better performance and larger memory cost).
- explicit_gemm(data: Tensor, octree: Octree, depth: int)
Performs the convolution by explicitly constructing the column (col) data.
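Conceptually, once the gathered neighbor features form an (N, K, Cin) column tensor, the convolution is a single matrix product. That whole tensor must be held in memory at once, which is why this path can be memory-hungry. The sketch below assumes a (K*Cin, Cout) weight layout; the library's actual weight layout is not claimed here.

```python
from typing import List, Sequence


def explicit_gemm_conv(col: Sequence[Sequence[Sequence[float]]],
                       weight: Sequence[Sequence[float]]) -> List[List[float]]:
    """Multiplies the flattened (N, K*Cin) column matrix by a (K*Cin, Cout) weight."""
    out_channels = len(weight[0])
    out = []
    for neighbors in col:
        flat = [v for feature in neighbors for v in feature]  # one row of length K*Cin
        out.append([sum(a * weight[r][o] for r, a in enumerate(flat))
                    for o in range(out_channels)])
    return out
```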
- block_gemm(data: Tensor, octree: Octree, depth: int)
Performs the convolution block by block, which reduces the required runtime memory.
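The sketch below illustrates the block-wise idea under stated assumptions, and it is not the library's code. Rows are processed in chunks so that each temporary column buffer holds at most max_buffer values. It uses neigh, a hypothetical (N, K) neighbor table with -1 for missing neighbors, and an assumed (K*Cin, Cout) weight layout. The result does not depend on max_buffer; only peak memory does.

```python
from typing import List, Sequence


def block_gemm_conv(data: Sequence[Sequence[float]],
                    neigh: Sequence[Sequence[int]],
                    weight: Sequence[Sequence[float]],
                    max_buffer: int) -> List[List[float]]:
    """Convolves in row blocks so that each column buffer holds at most max_buffer values."""
    channels = len(data[0])
    kdim = len(neigh[0])
    out_channels = len(weight[0])
    zeros = [0.0] * channels
    rows_per_block = max(1, max_buffer // (kdim * channels))
    out = []
    for start in range(0, len(neigh), rows_per_block):
        # Column buffer for this block only (-1 marks a missing neighbor).
        buffer = [[v for j in row for v in (data[j] if j >= 0 else zeros)]
                  for row in neigh[start:start + rows_per_block]]
        for flat in buffer:
            out.append([sum(a * weight[r][o] for r, a in enumerate(flat))
                        for o in range(out_channels)])
    return out
```

A smaller max_buffer means more, smaller blocks, trading some runtime for lower peak memory.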
- class OctreeDeconv(in_channels: int, out_channels: int, kernel_size: List[int] = [3], stride: int = 1, nempty: bool = False, method: str = 'triton', use_bias: bool = False, max_buffer: int = 200000000)
Performs octree deconvolution.
Please refer to OctreeConv for the meaning of the arguments.
- class OctreeGroupConv(in_channels: int, out_channels: int, kernel_size: List[int] = [3], stride: int = 1, nempty: bool = False, use_bias: bool = False, group: int = 1)
Performs octree-based group convolution.
- Parameters:
in_channels (int) – Number of input channels.
out_channels (int) – Number of output channels.
kernel_size (List(int)) – The kernel shape, chosen from [3], [2], [3,3,3], [3,1,1], [1,3,1], [1,1,3], [2,2,2], [3,3,1], [1,3,3], and [3,1,3].
stride (int) – The stride of the convolution (1 or 2).
nempty (bool) – If True, only performs the convolution on non-empty octree nodes.
use_bias (bool) – If True, add a bias term to the convolution.
group (int) – The number of groups.
Note
This module performs group convolution with a for-loop over the groups, so its performance is not optimal. Use it only when the number of groups is small; otherwise it may be slow.
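The following sketch shows why the group count drives the cost: each group runs its own small matrix product inside a Python loop. It works on gathered (N, K, Cin) columns. The per-group weight layout (K * Cin/group, Cout/group) is an assumption of this sketch, not the module's stored format.

```python
from typing import List, Sequence


def group_conv(col: Sequence[Sequence[Sequence[float]]],
               weights: Sequence[Sequence[Sequence[float]]],
               group: int) -> List[List[float]]:
    """Runs one small GEMM per channel group in a for-loop and concatenates the results."""
    in_channels = len(col[0][0])
    assert in_channels % group == 0, "in_channels must be divisible by group"
    gin = in_channels // group
    out = [[] for _ in col]
    for g in range(group):  # the per-group loop that makes large group counts slow
        w = weights[g]  # shape (K * gin, Cout / group)
        for n, neighbors in enumerate(col):
            flat = [v for feature in neighbors for v in feature[g * gin:(g + 1) * gin]]
            out[n].extend(sum(a * w[r][o] for r, a in enumerate(flat))
                          for o in range(len(w[0])))
    return out
```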
- class OctreeDWConv(in_channels: int, kernel_size: List[int] = [3], stride: int = 1, nempty: bool = False, use_bias: bool = False, max_buffer: int = 200000000)
Performs octree-based depth-wise convolution.
Please refer to ocnn.nn.OctreeConv for the meaning of the arguments.
Note
This implementation uses torch.einsum(), which I find relatively slow. Further optimization is needed to speed it up.
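In a depth-wise convolution each channel has its own kernel and channels never mix. On gathered (N, K, C) columns this is the contraction out[n][c] = sum_k col[n][k][c] * weight[k][c]. In tensor form it could be written torch.einsum('nkc,kc->nc', col, weight); whether the module uses exactly this subscript string is an assumption. The pure-Python sketch below spells out the same sum.

```python
from typing import List, Sequence


def depthwise_conv(col: Sequence[Sequence[Sequence[float]]],
                   weight: Sequence[Sequence[float]]) -> List[List[float]]:
    """out[n][c] = sum_k col[n][k][c] * weight[k][c]; channels never mix."""
    channels = len(weight[0])
    return [[sum(neighbors[k][c] * weight[k][c] for k in range(len(weight)))
             for c in range(channels)]
            for neighbors in col]
```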
- class OctreeInterp(method: str = 'linear', nempty: bool = False, bound_check: bool = False, rescale_pts: bool = True)
Interpolates the points with an octree feature.
Refer to octree_nearest_pts() for a description of the arguments.
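The rescale_pts flag suggests that the module can map points into the [0, 2^depth) range that octree_nearest_pts() requires. The sketch below assumes input coordinates normalized to [-1, 1], a common convention but an assumption here. Note that a coordinate of exactly 1 maps to 2^depth, which lies just outside the half-open range.

```python
from typing import Tuple


def rescale_point(pt: Tuple[float, float, float, int],
                  depth: int) -> Tuple[float, float, float, int]:
    """Maps x, y, z from [-1, 1] to [0, 2^depth]; the batch index is unchanged."""
    x, y, z, b = pt
    half = (1 << depth) / 2.0
    return ((x + 1.0) * half, (y + 1.0) * half, (z + 1.0) * half, b)
```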
- class OctreeUpsample(method: str = 'linear', nempty: bool = False)
Upsamples the octree node features from depth to target_depth.
Refer to octree_nearest_pts() for details.
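Consider method='nearest' with a one-level step (target_depth = depth + 1). A child's center at depth + 1 always falls inside its parent voxel at depth, so nearest-neighbor upsampling amounts to copying the parent's feature to each child. The sketch below assumes that one-level case and a hypothetical coordinate-to-feature dictionary; it is not the module's code.

```python
from typing import Dict, List, Sequence, Tuple

Key = Tuple[int, int, int, int]  # (x, y, z, batch)


def nearest_upsample(parent_features: Dict[Key, List[float]],
                     child_coords: Sequence[Key]) -> List[List[float]]:
    """Gives every child node at depth d+1 the feature of its parent at depth d."""
    out = []
    for x, y, z, b in child_coords:
        # The child center (x + 0.5) at depth d+1 is (x + 0.5) / 2 at depth d,
        # and flooring it gives x // 2, i.e. the parent voxel.
        out.append(list(parent_features[(x // 2, y // 2, z // 2, b)]))
    return out
```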
- class OctreeInstanceNorm(in_channels: int, nempty: bool = False)
An instance normalization layer for the octree.
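Instance normalization on an octree computes statistics per octree (per batch element) and per channel, over that octree's nodes. The sketch below makes several assumptions: it uses a hypothetical batch_id list, omits learnable affine parameters, uses the biased variance, and picks eps=1e-5. The layer's actual eps and affine handling are not documented above.

```python
import math
from typing import List, Sequence


def instance_norm(data: Sequence[Sequence[float]],
                  batch_id: Sequence[int],
                  batch_size: int,
                  eps: float = 1e-5) -> List[List[float]]:
    """Normalizes each channel with the mean/variance of the nodes of the same octree."""
    channels = len(data[0])
    out = [list(row) for row in data]
    for b in range(batch_size):
        rows = [i for i, bi in enumerate(batch_id) if bi == b]
        if not rows:
            continue
        for c in range(channels):
            vals = [data[i][c] for i in rows]
            mean = sum(vals) / len(vals)
            var = sum((v - mean) ** 2 for v in vals) / len(vals)
            for i in rows:
                out[i][c] = (data[i][c] - mean) / math.sqrt(var + eps)
    return out
```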
- OctreeBatchNorm
alias of BatchNorm1d
- class OctreeGroupNorm(in_channels: int, group: int, nempty: bool = False, min_group_channels: int = 4)
A group normalization layer for the octree.
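Group normalization pools the statistics of a channel group across all nodes of one octree. The sketch below does not model min_group_channels, whose exact effect is not documented above. It also assumes a hypothetical batch_id list, no affine parameters, the biased variance, and eps=1e-5.

```python
import math
from typing import List, Sequence


def group_norm(data: Sequence[Sequence[float]],
               batch_id: Sequence[int],
               batch_size: int,
               group: int,
               eps: float = 1e-5) -> List[List[float]]:
    """Normalizes each (octree, channel group) block with its own mean and variance."""
    channels = len(data[0])
    assert channels % group == 0, "channels must be divisible by group"
    size = channels // group
    out = [list(row) for row in data]
    for b in range(batch_size):
        rows = [i for i, bi in enumerate(batch_id) if bi == b]
        if not rows:
            continue
        for g in range(group):
            cols = range(g * size, (g + 1) * size)
            vals = [data[i][c] for i in rows for c in cols]
            mean = sum(vals) / len(vals)
            var = sum((v - mean) ** 2 for v in vals) / len(vals)
            std = math.sqrt(var + eps)
            for i in rows:
                for c in cols:
                    out[i][c] = (data[i][c] - mean) / std
    return out
```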
- class OctreeNorm(in_channels: int, norm_type: str = 'batch_norm', group: int = 32, min_group_channels: int = 4)
A normalization layer for the octree. It encapsulates octree-based batch, group, and instance normalization.
- forward(x: Tensor, octree: Octree | None = None, depth: int | None = None)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class OctreeDropPath(drop_prob: float = 0.0, nempty: bool = False, scale_by_keep: bool = True)
Drops paths (Stochastic Depth) per octree when applied in the main path of residual blocks.
- Parameters:
- forward(data: Tensor, octree: Octree, depth: int, batch_id: Tensor | None = None)
Defines the drop path forward function.
- Parameters:
data (torch.Tensor) – The input features of shape (N, C), where N is the number of octree nodes and C is the number of channels.
octree (Octree) – The input octree.
depth (int) – The depth of the octree layer.
batch_id (torch.Tensor, optional) – The batch indices of the octree nodes. If not provided, it will be extracted from the octree.
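"Per octree" means one keep-or-drop decision per batch element, shared by all of its nodes. The sketch below follows the timm-style logic referenced by DropPath: keep with probability 1 - drop_prob, then optionally divide by the keep probability. The training flag and the plain batch_id list are assumptions of this sketch.

```python
import random
from typing import List, Optional, Sequence


def octree_drop_path(data: Sequence[Sequence[float]],
                     batch_id: Sequence[int],
                     batch_size: int,
                     drop_prob: float = 0.0,
                     scale_by_keep: bool = True,
                     training: bool = True,
                     rng: Optional[random.Random] = None) -> List[List[float]]:
    """Zeroes whole octrees at random: one keep/drop draw per batch element."""
    if drop_prob == 0.0 or not training:
        return [list(row) for row in data]
    if rng is None:
        rng = random.Random()
    keep_prob = 1.0 - drop_prob
    # One Bernoulli(keep_prob) draw per octree, shared by all of its nodes.
    mask = [1.0 if rng.random() < keep_prob else 0.0 for _ in range(batch_size)]
    if scale_by_keep and keep_prob > 0.0:
        mask = [m / keep_prob for m in mask]  # keeps the expected value unchanged
    return [[v * mask[b] for v in row] for row, b in zip(data, batch_id)]
```

Scaling the kept octrees by 1 / keep_prob keeps the expected activation equal to the input, so nothing needs to change at inference time.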
- class DropPath(drop_prob: float = 0.0, scale_by_keep: bool = True, **kwargs)
Drops paths (Stochastic Depth) per token when applied in the main path of residual blocks, following the logic of timm.models.layers.DropPath().
- Parameters:
- forward(data: Tensor, **kwargs)
Defines the drop path forward function.
- Parameters:
data (torch.Tensor) – The input features of shape (N, C), where N is the number of tokens and C is the number of channels.
- search_value(value: Tensor, key: Tensor, query: Tensor)
Searches values according to sorted shuffled keys.
- Parameters:
value (torch.Tensor) – The input tensor with shape (N, C).
key (torch.Tensor) – The key tensor corresponding to value, with shape (N,), which contains the sorted shuffled keys of an octree.
query (torch.Tensor) – The query tensor, which also contains shuffled keys.
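Because the keys are sorted, each query can be resolved by binary search. The sketch below uses bisect as a stand-in for a tensor searchsorted. Returning a zero row for a query key that does not occur among the keys is an assumption of this sketch.

```python
from bisect import bisect_left
from typing import List, Sequence


def search_value_sketch(value: Sequence[Sequence[float]],
                        key: Sequence[int],
                        query: Sequence[int]) -> List[List[float]]:
    """For each query key, returns the value row of the matching sorted key, else zeros."""
    channels = len(value[0])
    out = []
    for q in query:
        i = bisect_left(key, q)  # binary search over the sorted keys
        if i < len(key) and key[i] == q:
            out.append(list(value[i]))
        else:
            out.append([0.0] * channels)  # query key absent from the octree
    return out
```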
- octree_align(value: Tensor, octree: Octree, octree_query: Octree, depth: int, nempty: bool = False)
Wraps search_value() to take octrees as input for convenience.
- class OctreeConvTriton(in_channels: int, out_channels: int, kernel_size: List[int] = [3], stride: int = 1, nempty: bool = False, method: str = 'triton', use_bias: bool = False, max_buffer: int = 200000000)
Performs octree convolution.
- Parameters:
in_channels (int) – Number of input channels.
out_channels (int) – Number of output channels.
kernel_size (List(int)) – The kernel shape; only [3] and [3,3,3] are currently supported by the Triton implementation.
stride (int) – The stride of the convolution; only 1 is currently supported.
nempty (bool) – If True, only performs the convolution on non-empty octree nodes; otherwise, performs the convolution on all octree nodes.
use_bias (bool) – If True, add a bias term to the convolution.
Note
Each non-empty octree node has exactly 8 child nodes, some of which are non-empty and some empty. If nempty is True, the convolution is performed on non-empty octree nodes only, which is exactly the same as SparseConvNet and MinkowskiNet. If nempty is False, the convolution is performed on all octree nodes, which is essential for shape reconstruction tasks and can also be used in classification and segmentation (with slightly better performance and larger memory cost).
- OctreeConvT
alias of OctreeConvTriton
- convert_conv_triton(module: Module) -> Module
Converts OctreeConv modules to OctreeConvTriton modules in a network.
- Parameters:
module (torch.nn.Module) – The input module.