Source code for ocnn.octree.points

# --------------------------------------------------------
# Octree-based Sparse Convolutional Neural Networks
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
# Licensed under The MIT License [see LICENSE for details]
# Written by Peng-Shuai Wang
# --------------------------------------------------------

import torch
import numpy as np
from typing import Optional, Union, List
from ocnn.utils import cumsum


class Points:
    r''' Represents a point cloud and contains some elementary transformations.

    All transformations (:meth:`scale`, :meth:`rotate`, :meth:`translate`,
    :meth:`flip`, :meth:`orient_normal`, ...) rebind the attributes to new
    tensors instead of writing into the existing ones, so tensors passed to the
    constructor or shared with other :class:`Points` objects are never
    modified.

    Args:
      points (torch.Tensor): The coordinates of the points with a shape of
          :obj:`(N, 3)`, where :obj:`N` is the number of points.
      normals (torch.Tensor or None): The point normals with a shape of
          :obj:`(N, 3)`.
      features (torch.Tensor or None): The point features with a shape of
          :obj:`(N, C)`, where :obj:`C` is the channel of features.
      labels (torch.Tensor or None): The point labels with a shape of
          :obj:`(N, K)`, where :obj:`K` is the channel of labels. A 1-D tensor
          of shape :obj:`(N,)` is accepted and reshaped to :obj:`(N, 1)`.
      batch_id (torch.Tensor or None): The batch indices for each point with a
          shape of :obj:`(N, 1)`. A 1-D tensor of shape :obj:`(N,)` is
          accepted and reshaped to :obj:`(N, 1)`.
      batch_size (int): The batch size.
    '''

    def __init__(self, points: torch.Tensor,
                 normals: Optional[torch.Tensor] = None,
                 features: Optional[torch.Tensor] = None,
                 labels: Optional[torch.Tensor] = None,
                 batch_id: Optional[torch.Tensor] = None,
                 batch_size: int = 1):
        super().__init__()
        self.points = points
        self.normals = normals
        self.features = features
        self.labels = labels
        self.batch_id = batch_id
        self.batch_size = batch_size
        self.device = points.device
        # Number of points per batch element; valid after `merge_points` (or
        # computed lazily by `split_points`).
        self.batch_npt = None
        self.check_input()

    def check_input(self):
        r''' Checks the input arguments.

        Also reshapes 1-D :attr:`labels` and :attr:`batch_id` to
        :obj:`(N, 1)`. Violations raise :obj:`AssertionError`.
        '''
        assert self.points.dim() == 2 and self.points.size(1) == 3
        npt = self.points.size(0)

        if self.normals is not None:
            assert self.normals.dim() == 2 and self.normals.size(1) == 3
            assert self.normals.size(0) == npt

        if self.features is not None:
            assert self.features.dim() == 2
            assert self.features.size(0) == npt

        if self.labels is not None:
            assert self.labels.dim() == 2 or self.labels.dim() == 1
            assert self.labels.size(0) == npt
            if self.labels.dim() == 1:
                self.labels = self.labels.unsqueeze(1)

        if self.batch_id is not None:
            assert self.batch_id.dim() == 2 or self.batch_id.dim() == 1
            assert self.batch_id.size(0) == npt
            if self.batch_id.dim() == 1:
                self.batch_id = self.batch_id.unsqueeze(1)
            assert self.batch_id.size(1) == 1
            # `max()` is undefined on an empty tensor, so the batch size can
            # only be cross-checked when there is at least one point.
            if self.batch_id.numel() > 0:
                assert self.batch_size == self.batch_id.max().item() + 1

    @property
    def npt(self):
        r''' The number of points, i.e. :obj:`points.shape[0]`. '''
        return self.points.shape[0]

    def orient_normal(self, axis: str = 'x'):
        r''' Orients the point normals along a given axis.

        For :obj:`x`, :obj:`y` or :obj:`z`, every normal whose component along
        that axis is not strictly positive is negated. For :obj:`xyz`, the
        absolute value of every normal component is taken. Does nothing when
        there are no normals.

        Args:
          axis (str): The coordinate axis, choose from :obj:`x`, :obj:`y`,
              :obj:`z` and :obj:`xyz`. (default: :obj:`x`)
        '''
        if self.normals is None:
            return

        axis_map = {'x': 0, 'y': 1, 'z': 2, 'xyz': 3}
        idx = axis_map[axis]
        if idx < 3:
            # A zero component is also flipped, since the test is `> 0`.
            positive = self.normals[:, idx:idx + 1] > 0
            self.normals = torch.where(positive, self.normals, -self.normals)
        else:
            # Out-of-place on purpose: `abs_()` would modify a tensor that may
            # be shared with the caller or with another Points object.
            self.normals = self.normals.abs()

    def scale(self, factor: torch.Tensor):
        r''' Rescales the point cloud.

        Normals follow the inverse-transpose rule: when the scaling is
        non-uniform or negative, they are multiplied by :obj:`1 / factor` and
        renormalized. A uniform positive scale leaves them unchanged.

        Args:
          factor (torch.Tensor): The scale factor with shape :obj:`(3,)`.
              Anything accepted by :func:`torch.as_tensor` also works.
        '''
        factor = torch.as_tensor(factor)
        non_zero = (factor != 0).all()
        all_ones = (factor == 1.0).all()
        non_uniform = (factor != factor[0]).any()
        assert non_zero, 'The scale factor must not contain 0.'
        if all_ones:
            return

        # Match the dtype of the points so that, e.g., a float64 factor does
        # not silently promote float32 points to float64.
        factor = factor.to(device=self.device, dtype=self.points.dtype)
        self.points = self.points * factor
        # A uniform negative factor reverses the normal direction (the same
        # rule `flip` applies), so it must be handled like non-uniform scaling.
        if self.normals is not None and (non_uniform or (factor < 0).any()):
            ifactor = 1.0 / factor
            self.normals = self.normals * ifactor
            norm2 = torch.sqrt(torch.sum(self.normals ** 2, dim=1, keepdim=True))
            self.normals = self.normals / torch.clamp(norm2, min=1.0e-12)

    def rotate(self, angle: torch.Tensor):
        r''' Rotates the point cloud.

        The points (row vectors) are right-multiplied by
        :obj:`Rx^T @ Ry^T @ Rz^T`, which equals rotating around the x axis
        first, then the y axis, then the z axis. Normals receive the same
        rotation.

        Args:
          angle (torch.Tensor): The rotation angles in radian with shape
              :obj:`(3,)`.
        '''
        angle = torch.as_tensor(angle)
        cos, sin = angle.cos(), angle.sin()
        one, zero = torch.ones_like(cos[0]), torch.zeros_like(cos[0])

        def matrix(rows):
            # Stacking 0-d tensors (instead of `torch.Tensor([...])`) keeps the
            # dtype/device of `angle` and the autograd graph.
            return torch.stack([torch.stack(row) for row in rows])

        # rotx, roty, rotz are actually the transpose of the rotation matrices,
        # since the points are row vectors.
        rotx = matrix([[one, zero, zero],
                       [zero, cos[0], sin[0]],
                       [zero, -sin[0], cos[0]]])
        roty = matrix([[cos[1], zero, -sin[1]],
                       [zero, one, zero],
                       [sin[1], zero, cos[1]]])
        rotz = matrix([[cos[2], sin[2], zero],
                       [-sin[2], cos[2], zero],
                       [zero, zero, one]])
        rot = rotx @ roty @ rotz

        # matmul requires identical dtypes, e.g. for float64 points.
        rot = rot.to(device=self.device, dtype=self.points.dtype)
        self.points = self.points @ rot
        if self.normals is not None:
            self.normals = self.normals @ rot.to(self.normals.dtype)

    def translate(self, dis: torch.Tensor):
        r''' Translates the point cloud.

        Args:
          dis (torch.Tensor): The displacement with shape :obj:`(3,)`.
        '''
        # Keep the dtype of the points instead of promoting them.
        dis = torch.as_tensor(dis).to(device=self.device, dtype=self.points.dtype)
        self.points = self.points + dis

    def flip(self, axis: str):
        r''' Flips the point cloud along the given :attr:`axis`.

        Each character of :attr:`axis` negates the corresponding coordinate of
        the points and normals; a repeated character flips twice.

        Args:
          axis (str): The flipping axis, chosen from :obj:`x`, :obj:`y`, and
              :obj:`z`, e.g. :obj:`'xz'`.
        '''
        axis_map = {'x': 0, 'y': 1, 'z': 2}
        sign = [1.0, 1.0, 1.0]
        for x in axis:
            idx = axis_map[x]
            sign[idx] = -sign[idx]

        # Multiply out-of-place so tensors shared with the caller are untouched.
        sign = torch.tensor(sign, dtype=self.points.dtype, device=self.device)
        self.points = self.points * sign
        if self.normals is not None:
            self.normals = self.normals * sign.to(self.normals.dtype)

    def clip(self, min: float = -1.0, max: float = 1.0, esp: float = 0.01):
        r''' Clips the point cloud to :obj:`[min+esp, max-esp]` and returns the mask.

        Points not strictly inside the box are removed from this object, and
        :attr:`batch_npt` is reset to :obj:`None`.

        Args:
          min (float): The minimum value to clip.
          max (float): The maximum value to clip.
          esp (float): The margin.

        Returns:
          torch.Tensor: A boolean mask of shape :obj:`(N,)` over the original
          points, :obj:`True` for the kept ones.
        '''
        mask = self.inbox_mask(min + esp, max - esp)
        self.copy_from(self[mask])
        return mask

    def __getitem__(self, idx):
        r''' Slices the point cloud according a given :attr:`idx`.

        The same index is applied to every per-point attribute. The result
        keeps :attr:`batch_size`; its :attr:`batch_npt` is :obj:`None`.
        '''
        def select(tensor):
            return None if tensor is None else tensor[idx]

        out = self.init_points(self.device, self.batch_size)
        out.points = self.points[idx]
        out.normals = select(self.normals)
        out.features = select(self.features)
        out.labels = select(self.labels)
        # Assign unconditionally: `init_points` creates a dummy batch_id for
        # batch_size > 1 that must not survive when the source has none.
        out.batch_id = select(self.batch_id)
        return out

    def inbox_mask(self, bbmin: Union[float, torch.Tensor] = -1.0,
                   bbmax: Union[float, torch.Tensor] = 1.0):
        r''' Returns a mask indicating whether the points are within the
        specified bounding box or not.

        The comparison is strict on both ends; the mask has shape :obj:`(N,)`.
        '''
        mask_min = torch.all(self.points > bbmin, dim=1)
        mask_max = torch.all(self.points < bbmax, dim=1)
        return torch.logical_and(mask_min, mask_max)

    def bbox(self):
        r''' Returns the bounding box as :obj:`(bbmin, bbmax)`, each of
        shape :obj:`(3,)`.
        '''
        bbmin = self.points.min(dim=0).values
        bbmax = self.points.max(dim=0).values
        return bbmin, bbmax

    def centralize_scale(self, bbmin: torch.Tensor, bbmax: torch.Tensor,
                         scale: float = 1.0):
        r''' Centralizes the point cloud to :obj:`[-scale, scale]`.

        The longest side of the box is mapped to :obj:`2 * scale`; a small
        epsilon is added to it to avoid division by zero.

        Args:
          bbmin (torch.Tensor): The minimum coordinates of the bounding box.
          bbmax (torch.Tensor): The maximum coordinates of the bounding box.
          scale (float): The scale factor
        '''
        center = (bbmin + bbmax) * 0.5
        box_size = (bbmax - bbmin).max() + 1.0e-6
        self.points = (self.points - center) * (2.0 * scale / box_size)

    def normalize(self, bbmin: torch.Tensor, bbmax: torch.Tensor,
                  scale: float, inplace: bool = False):
        r''' Normalizes the point cloud to :obj:`[0, scale]`.

        Args:
          bbmin (torch.Tensor): The minimum coordinates of the bounding box.
          bbmax (torch.Tensor): The maximum coordinates of the bounding box.
          scale (float): The scale factor.
          inplace (bool): If True, the normalization is performed in-place;
              otherwise, directly returns the normalized points without
              modifying the original points.

        Returns:
          torch.Tensor: The normalized points.
        '''
        box_size = bbmax - bbmin
        if isinstance(box_size, torch.Tensor):
            box_size = box_size.max().item()
        assert box_size > 0, 'The bounding box size must be greater than 0.'

        points = (self.points - bbmin) * (scale / box_size)
        if inplace:
            self.points = points
        return points

    def to(self, device: Union[torch.device, str], non_blocking: bool = False):
        r''' Moves the Points to a specified device.

        Args:
          device (torch.device or str): The destination device.
          non_blocking (bool): If True and the source is in pinned memory, the
              copy will be asynchronous with respect to the host. Otherwise,
              the argument has no effect. Default: False.

        Returns:
          Points: :obj:`self` if already on :attr:`device`, else a new Points.
        '''
        if isinstance(device, str):
            device = torch.device(device)

        # If on the same device, directly return self.
        if self.device == device:
            return self

        def move(tensor):
            if tensor is None:
                return None
            return tensor.to(device, non_blocking=non_blocking)

        # Construct a new Points on the specified device.
        points = self.init_points(device, self.batch_size)
        points.batch_npt = self.batch_npt
        points.points = move(self.points)
        points.normals = move(self.normals)
        points.features = move(self.features)
        points.labels = move(self.labels)
        # Overwrite the dummy batch_id created by `init_points` even when None.
        points.batch_id = move(self.batch_id)
        return points

    def cuda(self, non_blocking: bool = False):
        r''' Moves the Points to the GPU. '''
        return self.to('cuda', non_blocking)

    def cpu(self):
        r''' Moves the Points to the CPU. '''
        return self.to('cpu')

    def save(self, filename: str, info: str = 'PNFL'):
        r''' Save the Points into npz or xyz files.

        Attributes that are :obj:`None` are skipped. For xyz files the selected
        attributes are concatenated column-wise and written as text.

        Args:
          filename (str): The output filename, ending with :obj:`npz` or
              :obj:`xyz`.
          info (str): The information for saving: 'P' -> 'points',
              'N' -> 'normals', 'F' -> 'features', 'L' -> 'labels',
              'B' -> 'batch_id'.

        Raises:
          ValueError: If the file extension is not supported.
        '''
        mapping = {
            'P': ('points', self.points),
            'N': ('normals', self.normals),
            'F': ('features', self.features),
            'L': ('labels', self.labels),
            'B': ('batch_id', self.batch_id),
        }

        names, outs = [], []
        for key in info.upper():
            name, out = mapping[key]
            if out is not None:
                names.append(name)
                if out.dim() == 1:
                    out = out.unsqueeze(1)
                outs.append(out.cpu().numpy())

        if filename.endswith('npz'):
            np.savez(filename, **dict(zip(names, outs)))
        elif filename.endswith('xyz'):
            np.savetxt(filename, np.concatenate(outs, axis=1), fmt='%.6f')
        else:
            raise ValueError(
                f'Unsupported file format: {filename} (expected npz or xyz).')

    def copy_from(self, points: 'Points'):
        r''' Shallow copy from another Points (tensors are shared, not cloned). '''
        self.points = points.points
        self.normals = points.normals
        self.features = points.features
        self.labels = points.labels
        self.batch_id = points.batch_id
        self.batch_size = points.batch_size
        self.device = points.device
        self.batch_npt = points.batch_npt

    def merge_points(self, points: List['Points'], update_batch_info: bool = True):
        r''' Merges a list of points into one batch.

        Args:
          points (List[Points]): A list of points to merge. The batch size of
              each points in the list is assumed to be 1, and the
              :obj:`batch_size`, :obj:`batch_id`, and :obj:`batch_npt` in the
              points are ignored. Normals, features and labels are merged only
              if the first Points has them, otherwise they are set to None.
          update_batch_info (bool): If True, :attr:`batch_size`,
              :attr:`batch_npt` and :attr:`batch_id` are rebuilt from the list.

        Returns:
          Points: :obj:`self`.
        '''
        assert len(points) > 0, 'The input points list is empty.'

        def concat(name: str):
            # Presence is decided by the first element; all are assumed alike.
            if getattr(points[0], name) is None:
                return None
            return torch.cat([getattr(p, name) for p in points], dim=0)

        self.points = torch.cat([p.points for p in points], dim=0)
        self.normals = concat('normals')
        self.features = concat('features')
        self.labels = concat('labels')
        self.device = points[0].device

        if update_batch_info:
            self.batch_size = len(points)
            self.batch_npt = torch.tensor([p.npt for p in points], dtype=torch.long)
            # NOTE: batch_id takes the dtype of the points (e.g. float32), not
            # the long dtype used by `init_points`.
            self.batch_id = torch.cat(
                [p.points.new_full((p.npt, 1), i) for i, p in enumerate(points)],
                dim=0)
        return self

    def split_points(self):
        r''' Splits the batched points into a list of Points.

        The points are assumed to be sorted by batch index, i.e. batch :obj:`i`
        is the :obj:`i`-th contiguous chunk, as produced by
        :meth:`merge_points`. The tensors of the returned Points are slices
        (views) of the batched tensors.

        Returns:
          List[Points]: :attr:`batch_size` Points, each with batch size 1.
        '''
        if self.batch_npt is None:
            if self.batch_id is None:
                assert self.batch_size == 1, 'batch_id is required to split a batch.'
                self.batch_npt = torch.tensor([self.npt], dtype=torch.long)
            else:
                # `view(-1)` keeps a 1-D tensor even for a single point (unlike
                # `squeeze()`); `.long()` is needed because `merge_points`
                # stores batch_id with the floating dtype of the points.
                self.batch_npt = torch.bincount(
                    self.batch_id.view(-1).long(), minlength=self.batch_size)

        counts = self.batch_npt.tolist()
        outs = []
        start = 0
        for i in range(self.batch_size):
            end = start + counts[i]
            out = Points.init_points(self.device, batch_size=1)
            out.points = self.points[start:end]
            if self.normals is not None:
                out.normals = self.normals[start:end]
            if self.features is not None:
                out.features = self.features[start:end]
            if self.labels is not None:
                out.labels = self.labels[start:end]
            outs.append(out)
            start = end
        return outs

    @classmethod
    def init_points(cls, device: Union[torch.device, str, None] = None,
                    batch_size: int = 1):
        r''' Initializes a Points object with dummy data on a specified device.

        The dummy data has :obj:`batch_size` points at the origin; for
        :obj:`batch_size > 1` a long :obj:`batch_id` of :obj:`0..batch_size-1`
        is created so that :meth:`check_input` passes.

        Args:
          device (torch.device or str or None): The device of the Points. If
              :obj:`None`, the device is set to :obj:`cpu`.
          batch_size (int): The batch size.
        '''
        points = torch.zeros(batch_size, 3, device=device)
        batch_id = (torch.arange(batch_size, device=device).unsqueeze(1)
                    if batch_size > 1 else None)
        return cls(points, batch_size=batch_size, batch_id=batch_id)
def merge_points(points: List['Points'], update_batch_info: bool = True):
    r''' A wrapper of :meth:`Points.merge_points`.

    .. deprecated:: 2.2.7
       Use :meth:`Points.merge_points` instead.

    Args:
      points (List[Points]): The point clouds to merge; must not be empty.
      update_batch_info (bool): Forwarded to :meth:`Points.merge_points`.

    Returns:
      Points: A fresh Points object, created on the device of the first
      element, that holds the merged batch.
    '''
    assert len(points) > 0, 'The input points list is empty.'
    batch_size = len(points)
    merged = Points.init_points(points[0].device, batch_size=batch_size)
    merged.merge_points(points, update_batch_info)
    return merged