# Source code for ocnn.dataset

# --------------------------------------------------------
# Octree-based Sparse Convolutional Neural Networks
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
# Licensed under The MIT License [see LICENSE for details]
# Written by Peng-Shuai Wang
# --------------------------------------------------------

import math

import torch

import ocnn
from ocnn.octree import Octree, Points


# Public API of this module.
__all__ = [
    'Transform',
    'CollateBatch',
]
# `classes` aliases the very same list object as `__all__` (presumably read
# elsewhere in the package; kept for compatibility).
classes = __all__


class Transform:
    r''' A boilerplate class which transforms input data for :obj:`ocnn`.

    The input data is first converted to :class:`Points`, then randomly
    transformed (if enabled), and converted to an :class:`Octree`.

    Args:
      depth (int): The octree depth.
      full_depth (int): The octree layers with a depth smaller than
          :attr:`full_depth` are forced to be full.
      distort (bool): If true, performs the data augmentation.
      angle (list): A list of 3 values (in degrees) bounding the random
          rotation angle around each axis; int or float values are accepted.
      interval (list): A list of 3 values (in degrees) giving the step between
          candidate rotation angles on each axis; int or float values are
          accepted.
      scale (float): The maximum relative scale factor.
      uniform (bool): If true, performs uniform scaling.
      jitter (float): The maximum jitter values.
      flip (list): A list of 3 float values to represent the probability of
          flipping each axis.
      orient_normal (str): Orient point normals along the specified axis,
          which is useful when normals are not oriented.
    '''

    def __init__(self, depth: int, full_depth: int, distort: bool, angle: list,
                 interval: list, scale: float, uniform: bool, jitter: float,
                 flip: list, orient_normal: str = '', **kwargs):
        super().__init__()

        # for octree building
        self.depth = depth
        self.full_depth = full_depth

        # for data augmentation
        self.distort = distort
        self.angle = angle
        self.interval = interval
        self.scale = scale
        self.uniform = uniform
        self.jitter = jitter
        self.flip = flip

        # for other transformations
        self.orient_normal = orient_normal

    def __call__(self, sample: dict, idx: int):
        r''' Runs :meth:`preprocess`, :meth:`transform` and
        :meth:`points2octree` on :attr:`sample`.

        Returns the transformed sample dict, with the built octree stored
        under the ``'octree'`` key.
        '''
        output = self.preprocess(sample, idx)
        output = self.transform(output, idx)
        output['octree'] = self.points2octree(output['points'])
        return output

    def preprocess(self, sample: dict, idx: int):
        r''' Transforms :attr:`sample` to :class:`Points` and performs some
        specific transformations, like normalization.

        The ``'points'`` (and optional ``'normals'``) entries are popped from
        :attr:`sample`, which is modified in place and returned with a
        :class:`Points` object under ``'points'``. The popped values are passed
        to :func:`torch.from_numpy`, so they must be numpy arrays.
        '''
        xyz = torch.from_numpy(sample.pop('points'))
        normals = sample.pop('normals', None)
        if normals is not None:
            normals = torch.from_numpy(normals)
        sample['points'] = Points(xyz, normals)
        return sample

    def transform(self, sample: dict, idx: int):
        r''' Applies the general transformations provided by :obj:`ocnn`.

        When :attr:`distort` is true, the points are randomly flipped,
        rotated, translated (jittered) and scaled, in that order. Normals are
        then optionally oriented, and the points are clipped to ``[-1, 1]``;
        the value returned by the clip is stored under ``'inbox_mask'``.
        '''
        points = sample['points']
        if self.distort:
            rnd_angle, rnd_scale, rnd_jitter, rnd_flip = self.rnd_parameters()
            points.flip(rnd_flip)
            points.rotate(rnd_angle)
            points.translate(rnd_jitter)
            points.scale(rnd_scale)

        if self.orient_normal:
            points.orient_normal(self.orient_normal)

        # !!! NOTE: Clip the point cloud to [-1, 1] before building the octree
        inbox_mask = points.clip(min=-1, max=1)
        sample.update({'points': points, 'inbox_mask': inbox_mask})
        return sample

    def points2octree(self, points: 'Points'):
        r''' Converts the input :attr:`points` to an octree.
        '''
        octree = Octree(self.depth, self.full_depth)
        octree.build_octree(points)
        return octree

    def rnd_parameters(self):
        r''' Generates random parameters for data augmentation.

        Returns:
          tuple: ``(rnd_angle, rnd_scale, rnd_jitter, rnd_flip)`` where
          ``rnd_angle`` is a tensor of 3 rotation angles in radians, each an
          integer multiple of ``interval[i]`` within ``[-angle[i], angle[i]]``;
          ``rnd_scale`` is a tensor of 3 factors in ``[1 - scale, 1 + scale)``
          (all equal when :attr:`uniform` is true); ``rnd_jitter`` is a tensor
          of 3 offsets in ``[-jitter, jitter)``; and ``rnd_flip`` is a string
          listing the axes (a subset of ``'xyz'``) to flip.
        '''
        deg2rad = math.pi / 180.0
        rnd_angle = []
        for i in range(3):
            # `int()` is required: with float `angle`/`interval` values, `//`
            # yields a float, and `torch.randint` only accepts integer bounds.
            rot_num = int(self.angle[i] // self.interval[i])
            rnd = torch.randint(low=-rot_num, high=rot_num + 1, size=(1,))
            rnd_angle.append(rnd * self.interval[i] * deg2rad)
        rnd_angle = torch.cat(rnd_angle)

        rnd_scale = torch.rand(3) * (2 * self.scale) - self.scale + 1.0
        if self.uniform:
            rnd_scale[1:] = rnd_scale[0]

        # One Bernoulli draw per axis, in x, y, z order.
        flipped_axes = []
        for i, axis in enumerate('xyz'):
            if torch.rand([1]) < self.flip[i]:
                flipped_axes.append(axis)
        rnd_flip = ''.join(flipped_axes)

        rnd_jitter = torch.rand(3) * (2 * self.jitter) - self.jitter
        return rnd_angle, rnd_scale, rnd_jitter, rnd_flip
class CollateBatch:
    r''' Merge a list of octrees and points into a batch.

    Each sample in the batch is a dict; values are gathered per key into a
    list. Keys containing ``'octree'`` are merged into one octree (with its
    neighbor indices constructed), keys containing ``'points'`` are merged
    only when :attr:`merge_points` is true, and keys containing ``'label'``
    are converted to a :class:`torch.Tensor`. Other keys stay lists.

    Args:
      merge_points (bool): If true, merges the points of the batch with
          :func:`ocnn.octree.merge_points`.
    '''

    def __init__(self, merge_points: bool = False):
        self.merge_points = merge_points

    def __call__(self, batch: list):
        r''' Collates :attr:`batch` (a list of sample dicts) into one dict.

        An empty batch yields an empty dict.
        '''
        assert isinstance(batch, list)
        if not batch:
            return {}

        outputs = {}
        # Keys are taken from the first sample; every sample is expected to
        # provide the same keys.
        for key in batch[0].keys():
            outputs[key] = [b[key] for b in batch]

            # Merge a batch of octrees into one super octree
            if 'octree' in key:
                octree = ocnn.octree.merge_octrees(outputs[key])
                # NOTE: remember to construct the neighbor indices
                octree.construct_all_neigh()
                outputs[key] = octree

            # Merge a batch of points
            if 'points' in key and self.merge_points:
                outputs[key] = ocnn.octree.merge_points(outputs[key])

            # Convert the labels to a Tensor. Write back under `key` itself;
            # always writing to 'label' left keys such as 'labels' as lists
            # and overwrote the real 'label' entry.
            if 'label' in key:
                outputs[key] = torch.tensor(outputs[key])
        return outputs