ocnn.models

LeNet

Octree-based LeNet for classification.

ResNet

Octree-based ResNet for classification.

SegNet

Octree-based SegNet for segmentation.

UNet

Octree-based UNet for segmentation.

HRNet

Octree-based HRNet for classification and segmentation.

AutoEncoder

Octree-based AutoEncoder for shape encoding and decoding.

OUNet

Image2Shape

Octree-based AutoEncoder for shape encoding and decoding.

class LeNet(in_channels: int, out_channels: int, stages: int, nempty: bool = False)

Octree-based LeNet for classification.

forward(data: Tensor, octree: Octree, depth: int)

class ResNet(in_channels: int, out_channels: int, resblock_num: int, stages: int, nempty: bool = False, dropout: float = 0.5)

Octree-based ResNet for classification.

forward(data: Tensor, octree: Octree, depth: int)

class SegNet(in_channels: int, out_channels: int, stages: int, interp: str = 'linear', nempty: bool = False, **kwargs)

Octree-based SegNet for segmentation.

forward(data: Tensor, octree: Octree, depth: int, query_pts: Tensor)

class UNet(in_channels: int, out_channels: int, interp: str = 'linear', nempty: bool = False, **kwargs)

Octree-based UNet for segmentation.

config_network()

Configure the network channels and Resblock numbers.

unet_encoder(data: Tensor, octree: Octree, depth: int)

The encoder of the U-Net.

unet_decoder(convd: Dict[int, Tensor], octree: Octree, depth: int)

The decoder of the U-Net.
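
The `convd: Dict[int, Tensor]` argument suggests that the encoder stores one feature per octree depth and the decoder looks these up again as skip connections. The following is a minimal sketch of that pattern, not the ocnn implementation: strings stand in for tensors, and `toy_encoder`, `toy_decoder`, and the choice of four stages are assumptions made for illustration.

```python
from typing import Dict


def toy_encoder(data: str, depth: int, stages: int = 4) -> Dict[int, str]:
    """Collect one feature per octree depth, from `depth` down to `depth - stages`."""
    convd = {depth: f"conv({data})"}
    for i in range(stages):
        d = depth - i
        # Downsample from depth d to d - 1 and store the result under its depth.
        convd[d - 1] = f"down({convd[d]})"
    return convd


def toy_decoder(convd: Dict[int, str], depth: int, stages: int = 4) -> str:
    """Walk back up, fusing each upsampled feature with the stored skip feature."""
    feat = convd[depth - stages]
    for i in range(stages):
        d = depth - stages + i
        # Upsample from depth d to d + 1, then merge with the encoder
        # feature that was stored under depth d + 1.
        skip = convd[d + 1]
        feat = f"fuse(up({feat}), {skip})"
    return feat


convd = toy_encoder("x", depth=6)
print(sorted(convd))  # → [2, 3, 4, 5, 6]
print(toy_decoder(convd, depth=6))
```

Keying the dictionary by depth means the decoder does not need to know the order in which the encoder produced its features; it only needs the depth it is currently working at.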

forward(data: Tensor, octree: Octree, depth: int, query_pts: Tensor)

class HRNet(in_channels: int, out_channels: int, stages: int = 3, interp: str = 'linear', nempty: bool = False)

Octree-based HRNet for classification and segmentation.

forward(data: Tensor, octree: Octree, depth: int)

class AutoEncoder(channel_in: int, channel_out: int, depth: int, full_depth: int = 2, feature: str = 'ND')

Octree-based AutoEncoder for shape encoding and decoding.

Parameters:
  • channel_in (int) – The channel of the input signal.

  • channel_out (int) – The channel of the output signal.

  • depth (int) – The depth of the octree.

  • full_depth (int) – The full depth of the octree.

  • feature (str) – The feature type of the input signal. For details of this argument, please refer to ocnn.modules.InputFeature.

encoder(octree: Octree)

The encoder network of the AutoEncoder.

decoder(shape_code: Tensor, octree: Octree, update_octree: bool = False)

The decoder network of the AutoEncoder.

decode_code(shape_code: Tensor)

Decodes the shape code to an output octree.

Parameters:

shape_code (torch.Tensor) – The shape code for decoding.
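
Decoding to an octree is typically a coarse-to-fine process: starting from a full octree, the network predicts for each node whether it should be split, and every split node contributes eight children to the next depth. Below is a rough sketch of the resulting node-count bookkeeping, assuming this split-and-grow scheme applies here. `grow_node_counts` and `split_fn` are hypothetical names, and the toy rule replaces the learned split prediction.

```python
from typing import Callable, Dict


def grow_node_counts(nnum_start: int, full_depth: int, depth: int,
                     split_fn: Callable[[int, int], bool]) -> Dict[int, int]:
    """Return the number of nodes at each depth from full_depth to depth."""
    nnum = {full_depth: nnum_start}
    for d in range(full_depth, depth):
        # Ask the (toy) predictor which nodes at depth d should be subdivided.
        num_split = sum(1 for i in range(nnum[d]) if split_fn(d, i))
        # Every split node produces exactly eight children at depth d + 1.
        nnum[d + 1] = 8 * num_split
    return nnum


# Toy rule: split every second node. A real decoder would use predicted labels.
counts = grow_node_counts(nnum_start=64, full_depth=2, depth=4,
                          split_fn=lambda d, i: i % 2 == 0)
print(counts)  # → {2: 64, 3: 256, 4: 1024}
```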

init_octree(shape_code: Tensor)

Initialize a full octree for decoding.

Parameters:

shape_code (torch.Tensor) – The shape code for decoding, used to get the batch_size and device to initialize the output octree.
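
In a full octree every node is subdivided, so depth d holds 8^d nodes per shape. The sketch below tabulates those counts up to `full_depth`. It assumes that the shapes of a batch are merged into one octree, so that the counts scale with `batch_size`; `full_octree_node_counts` is a hypothetical helper, not part of ocnn.

```python
from typing import List


def full_octree_node_counts(full_depth: int, batch_size: int = 1) -> List[int]:
    """Node count at each depth 0..full_depth of a full octree."""
    # Each level multiplies the node count by 8 (one child per octant).
    return [batch_size * 8 ** d for d in range(full_depth + 1)]


print(full_octree_node_counts(full_depth=2, batch_size=4))  # → [4, 32, 256]
```

With the default `full_depth = 2`, this gives 64 nodes per shape at the depth where decoding starts.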

forward(octree: Octree, update_octree: bool)

class OUNet(channel_in: int, channel_out: int, depth: int, full_depth: int = 2, feature: str = 'ND')

encoder(octree)

The encoder network for extracting hierarchical features.

decoder(convs: dict, octree_in: Octree, octree_out: Octree, update_octree: bool = False)

The decoder network for decoding the octree.

init_octree(octree_in: Octree)

Initialize a full octree for decoding.

forward(octree_in, octree_out=None, update_octree: bool = False)

class Image2Shape(channel_out: int, depth: int, full_depth: int = 2, code_channel: int = 32)

Octree-based AutoEncoder for shape encoding and decoding.

Parameters:
  • channel_out (int) – The channel of the output signal.

  • depth (int) – The depth of the octree.

  • full_depth (int) – The full depth of the octree.

decoder(shape_code: Tensor, octree: Octree, update_octree: bool = False)

The decoder network of the AutoEncoder.

decode_code(shape_code: Tensor)

Decodes the shape code to an output octree.

Parameters:

shape_code (torch.Tensor) – The shape code for decoding.

init_octree(shape_code: Tensor)

Initialize a full octree for decoding.

Parameters:

shape_code (torch.Tensor) – The shape code for decoding, used to get the batch_size and device to initialize the output octree.

forward(image: Tensor, octree: Octree | None = None, update_octree: bool = False)
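
The optional `octree` argument suggests two calling modes: one where a target octree is supplied (for example during training) and one where the output octree has to be built from the shape code alone. The toy sketch below shows that dispatch under the assumption that a missing octree triggers initialization and forces the octree to be updated. `toy_forward` and its dictionary values are stand-ins, not the ocnn implementation.

```python
from typing import Optional


def toy_forward(image: str, octree: Optional[dict] = None,
                update_octree: bool = False) -> dict:
    shape_code = f"code({image})"  # stand-in for the image encoder
    if octree is None:
        # No octree given: start from a fresh one and let the decoder grow it.
        octree = {"source": "init_octree", "grown": False}
        update_octree = True
    else:
        octree = dict(octree)  # work on a copy so the caller's dict is untouched
    if update_octree:
        # Stand-in for the decoder adding nodes from its own split predictions.
        octree["grown"] = True
    return {"shape_code": shape_code, "octree_out": octree}


generated = toy_forward("img")
supervised = toy_forward("img", octree={"source": "ground_truth", "grown": False})
print(generated["octree_out"]["source"], supervised["octree_out"]["source"])
```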