# --------------------------------------------------------
# Octree-based Sparse Convolutional Neural Networks
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
# Licensed under The MIT License [see LICENSE for details]
# Written by Peng-Shuai Wang
# --------------------------------------------------------
import torch
import numpy as np
from typing import Optional, Union, List
from ocnn.utils import cumsum
class Points:
    r''' Represents a point cloud and contains some elementary transformations.

    Args:
      points (torch.Tensor): The coordinates of the points with a shape of
          :obj:`(N, 3)`, where :obj:`N` is the number of points.
      normals (torch.Tensor or None): The point normals with a shape of
          :obj:`(N, 3)`.
      features (torch.Tensor or None): The point features with a shape of
          :obj:`(N, C)`, where :obj:`C` is the channel of features.
      labels (torch.Tensor or None): The point labels with a shape of
          :obj:`(N, K)`, where :obj:`K` is the channel of labels.
      batch_id (torch.Tensor or None): The batch indices for each point with a
          shape of :obj:`(N, 1)`.
      batch_size (int): The batch size.
    '''

    # Optional per-point attributes that are sliced / moved together with
    # `points`. Each of them is either None or has one row per point.
    _OPTIONAL_ATTRS = ('normals', 'features', 'labels', 'batch_id')

    def __init__(self, points: torch.Tensor,
                 normals: Optional[torch.Tensor] = None,
                 features: Optional[torch.Tensor] = None,
                 labels: Optional[torch.Tensor] = None,
                 batch_id: Optional[torch.Tensor] = None,
                 batch_size: int = 1):
        super().__init__()
        self.points = points
        self.normals = normals
        self.features = features
        self.labels = labels
        self.batch_id = batch_id
        self.batch_size = batch_size
        self.device = points.device
        self.batch_npt = None  # valid after `merge_points`
        self.check_input()

    def check_input(self):
        r''' Checks that the input tensors have consistent shapes.

        :attr:`points` must have the shape :obj:`(N, 3)`, :attr:`normals` (if
        given) must have the shape :obj:`(N, 3)`, and every other optional
        per-point attribute that is given must have :obj:`N` rows.

        Raises:
          AssertionError: If any of the shapes is inconsistent.
        '''
        assert self.points.dim() == 2 and self.points.shape[1] == 3, \
            'The shape of input points must be (N, 3).'
        if self.normals is not None:
            assert self.normals.dim() == 2 and self.normals.shape[1] == 3, \
                'The shape of input normals must be (N, 3).'
        npt = self.points.shape[0]
        for name in self._OPTIONAL_ATTRS:
            value = getattr(self, name)
            if value is not None:
                assert value.dim() >= 1 and value.shape[0] == npt, \
                    f'The number of rows of {name} must equal the number of points.'

    @property
    def npt(self):
        r''' The number of points, i.e. :obj:`points.shape[0]`. '''
        return self.points.shape[0]

    def orient_normal(self, axis: str = 'x'):
        r''' Orients the point normals along a given axis.

        Does nothing if :attr:`normals` is None.

        Args:
          axis (str): One of :obj:`x`, :obj:`y`, :obj:`z` -- every normal whose
              component along that axis is not positive is negated -- or
              :obj:`xyz` -- every normal component is replaced by its absolute
              value (in-place). (default: :obj:`x`)
        '''
        if self.normals is None:
            return
        axis_map = {'x': 0, 'y': 1, 'z': 2, 'xyz': 3}
        idx = axis_map[axis]
        if idx < 3:
            flags = self.normals[:, idx] > 0
            flags = flags.float() * 2.0 - 1.0  # [0, 1] -> [-1, 1]
            self.normals = self.normals * flags.unsqueeze(1)
        else:
            self.normals.abs_()

    def scale(self, factor: torch.Tensor):
        r''' Rescales the point cloud.

        Normals (if any) are transformed by the reciprocal factor and
        re-normalized when the scale is non-uniform.

        Args:
          factor (torch.Tensor): The scale factor with shape :obj:`(3,)`; a
              single-element (or 0-dim) tensor applies a uniform scale.
        '''
        factor = torch.as_tensor(factor)
        assert (factor != 0).all(), 'The scale factor must not contain 0.'
        if (factor == 1.0).all():
            return
        # Compare with the first element of a flattened view so that 0-dim
        # factors are accepted too.
        non_uniform = (factor != factor.reshape(-1)[0]).any()
        # Keep the dtype of the points instead of silently promoting them.
        factor = factor.to(device=self.device, dtype=self.points.dtype)
        self.points = self.points * factor
        if self.normals is not None and non_uniform:
            # Normals transform with the inverse(-transpose) of the diagonal
            # scaling, i.e. the element-wise reciprocal, then are renormalized.
            ifactor = 1.0 / factor
            self.normals = self.normals * ifactor
            norm2 = torch.sqrt(torch.sum(self.normals ** 2, dim=1, keepdim=True))
            self.normals = self.normals / torch.clamp(norm2, min=1.0e-12)

    def rotate(self, angle: torch.Tensor):
        r''' Rotates the point cloud.

        The points (and normals) are rotated about the x axis, then the y axis,
        and finally the z axis of the fixed coordinate frame.

        Args:
          angle (torch.Tensor): The rotation angles in radian with shape
              :obj:`(3,)`.
        '''
        angle = torch.as_tensor(angle)
        cos, sin = angle.cos().tolist(), angle.sin().tolist()
        # Points are row vectors multiplied from the left, so rotx, roty and
        # rotz are the transposes of the usual rotation matrices.
        rotx = [[1, 0, 0], [0, cos[0], sin[0]], [0, -sin[0], cos[0]]]
        roty = [[cos[1], 0, -sin[1]], [0, 1, 0], [sin[1], 0, cos[1]]]
        rotz = [[cos[2], sin[2], 0], [-sin[2], cos[2], 0], [0, 0, 1]]
        # Build the matrix in the dtype of the points: `@` does not promote
        # dtypes, so a float32 matrix would fail on float64 points.
        dtype = self.points.dtype
        rot = (torch.tensor(rotx, dtype=dtype) @ torch.tensor(roty, dtype=dtype)
               @ torch.tensor(rotz, dtype=dtype))
        rot = rot.to(self.device)
        self.points = self.points @ rot
        if self.normals is not None:
            self.normals = self.normals @ rot.to(self.normals.dtype)

    def translate(self, dis: torch.Tensor):
        r''' Translates the point cloud.

        Args:
          dis (torch.Tensor): The displacement with shape :obj:`(3,)`.
        '''
        # Keep the dtype of the points instead of silently promoting them.
        dis = torch.as_tensor(dis).to(device=self.device, dtype=self.points.dtype)
        self.points = self.points + dis

    def flip(self, axis: str):
        r''' Flips the point cloud along the given :attr:`axis` (in-place).

        Args:
          axis (str): The flipping axes; every character is one of :obj:`x`,
              :obj:`y` and :obj:`z`, e.g. :obj:`xy` flips along both x and y.
        '''
        axis_map = {'x': 0, 'y': 1, 'z': 2}
        for x in axis:
            idx = axis_map[x]
            self.points[:, idx] *= -1.0
            if self.normals is not None:
                self.normals[:, idx] *= -1.0

    def clip(self, min: float = -1.0, max: float = 1.0, esp: float = 0.01):
        r''' Clips the point cloud to :obj:`(min+esp, max-esp)` and returns the
        mask.

        Points that are not strictly inside the box are removed in-place.

        Args:
          min (float): The minimum value to clip.
          max (float): The maximum value to clip.
          esp (float): The margin.

        Returns:
          torch.Tensor: A boolean mask of the points that were kept.
        '''
        mask = self.inbox_mask(min + esp, max - esp)
        self.copy_from(self[mask])
        return mask

    def __getitem__(self, idx):
        r''' Slices the point cloud according to a given :attr:`idx`.

        Every per-point attribute is indexed with the same :attr:`idx`;
        :attr:`batch_npt` of the result is None.
        '''
        out = self.init_points(self.device, self.batch_size)
        out.points = self.points[idx]
        # Assign every attribute (including None) so that the dummy batch_id
        # created by `init_points` for batch_size > 1 never leaks into `out`.
        for name in self._OPTIONAL_ATTRS:
            value = getattr(self, name)
            setattr(out, name, None if value is None else value[idx])
        return out

    def inbox_mask(self, bbmin: Union[float, torch.Tensor] = -1.0,
                   bbmax: Union[float, torch.Tensor] = 1.0):
        r''' Returns a mask indicating whether the points are strictly within the
        specified bounding box or not.
        '''
        mask_min = torch.all(self.points > bbmin, dim=1)
        mask_max = torch.all(self.points < bbmax, dim=1)
        return torch.logical_and(mask_min, mask_max)

    def bbox(self):
        r''' Returns the bounding box as a tuple :obj:`(bbmin, bbmax)`.
        '''
        # torch.min and torch.max with `dim` return (values, indices).
        bbmin = self.points.min(dim=0).values
        bbmax = self.points.max(dim=0).values
        return bbmin, bbmax

    def centralize_scale(self, bbmin: torch.Tensor, bbmax: torch.Tensor,
                         scale: float = 1.0):
        r''' Centralizes the point cloud to :obj:`[-scale, scale]`.

        Args:
          bbmin (torch.Tensor): The minimum coordinates of the bounding box.
          bbmax (torch.Tensor): The maximum coordinates of the bounding box.
          scale (float): The scale factor.
        '''
        center = (bbmin + bbmax) * 0.5
        # The small epsilon guards against a degenerate (zero-size) box.
        box_size = (bbmax - bbmin).max() + 1.0e-6
        self.points = (self.points - center) * (2.0 * scale / box_size)

    def normalize(self, bbmin: torch.Tensor, bbmax: torch.Tensor,
                  scale: float, inplace: bool = False):
        r''' Normalizes the point cloud to :obj:`[0, scale]`.

        Args:
          bbmin (torch.Tensor): The minimum coordinates of the bounding box.
          bbmax (torch.Tensor): The maximum coordinates of the bounding box.
          scale (float): The scale factor.
          inplace (bool): If True, the normalization is performed in-place;
              otherwise, directly returns the normalized points without
              modifying the original points.

        Returns:
          torch.Tensor: The normalized points.
        '''
        box_size = bbmax - bbmin
        if isinstance(box_size, torch.Tensor):
            # Use the longest edge so that the aspect ratio is preserved.
            box_size = box_size.max().item()
        assert box_size > 0, 'The bounding box size must be greater than 0.'
        points = (self.points - bbmin) * (scale / box_size)
        if inplace:
            self.points = points
        return points

    def to(self, device: Union[torch.device, str], non_blocking: bool = False):
        r''' Moves the Points to a specified device.

        Args:
          device (torch.device or str): The destination device.
          non_blocking (bool): If True and the source is in pinned memory, the
              copy will be asynchronous with respect to the host. Otherwise, the
              argument has no effect. Default: False.

        Returns:
          Points: :obj:`self` if it is already on :attr:`device`, otherwise a
          new Points on :attr:`device`.
        '''
        if isinstance(device, str):
            device = torch.device(device)

        # If on the same device, directly return self
        if self.device == device:
            return self

        # Construct a new Points on the specified device
        points = self.init_points(device, self.batch_size)
        points.batch_npt = self.batch_npt
        points.points = self.points.to(device, non_blocking=non_blocking)
        # Assign every attribute (including None) so that the dummy batch_id
        # created by `init_points` for batch_size > 1 never leaks through.
        for name in self._OPTIONAL_ATTRS:
            value = getattr(self, name)
            setattr(points, name, None if value is None
                    else value.to(device, non_blocking=non_blocking))
        return points

    def cuda(self, non_blocking: bool = False):
        r''' Moves the Points to the GPU. '''
        return self.to('cuda', non_blocking)

    def cpu(self):
        r''' Moves the Points to the CPU. '''
        return self.to('cpu')

    def save(self, filename: str, info: str = 'PNFL'):
        r''' Saves the Points into npz or xyz files.

        Attributes that are None are skipped; 1-D attributes are saved as a
        single column.

        Args:
          filename (str): The output filename, ending with :obj:`npz` or
              :obj:`xyz`.
          info (str): The information for saving: 'P' -> 'points',
              'N' -> 'normals', 'F' -> 'features', 'L' -> 'labels',
              'B' -> 'batch_id'.

        Raises:
          ValueError: If the file extension is not supported.
        '''
        mapping = {
            'P': ('points', self.points), 'N': ('normals', self.normals),
            'F': ('features', self.features), 'L': ('labels', self.labels),
            'B': ('batch_id', self.batch_id), }

        names, outs = [], []
        for key in info.upper():
            name, out = mapping[key]
            if out is not None:
                names.append(name)
                if out.dim() == 1:
                    out = out.unsqueeze(1)
                outs.append(out.cpu().numpy())

        if filename.endswith('npz'):
            out_dict = dict(zip(names, outs))
            np.savez(filename, **out_dict)
        elif filename.endswith('xyz'):
            out_array = np.concatenate(outs, axis=1)
            np.savetxt(filename, out_array, fmt='%.6f')
        else:
            raise ValueError(
                f'Unsupported file type: {filename}; expected npz or xyz.')

    def copy_from(self, points: 'Points'):
        r''' Shallow copy from another Points.
        '''
        self.points = points.points
        self.normals = points.normals
        self.features = points.features
        self.labels = points.labels
        self.batch_id = points.batch_id
        self.batch_size = points.batch_size
        self.device = points.device
        self.batch_npt = points.batch_npt

    def merge_points(self, points: List['Points'], update_batch_info: bool = True):
        r''' Merges a list of points into one batch.

        Args:
          points (List[Points]): A list of points to merge. The batch size of
              each points in the list is assumed to be 1, and the
              :obj:`batch_size`, :obj:`batch_id`, and :obj:`batch_npt` in the
              points are ignored. Whether normals / features / labels are
              merged is decided by the first element of the list.
          update_batch_info (bool): If True, :obj:`batch_size`,
              :obj:`batch_npt` and :obj:`batch_id` are rebuilt from the list.

        Returns:
          Points: :obj:`self`.
        '''
        assert len(points) > 0, 'The input points list is empty.'
        first = points[0]
        self.points = torch.cat([p.points for p in points], dim=0)
        for name in ('normals', 'features', 'labels'):
            if getattr(first, name) is not None:
                setattr(self, name, torch.cat([getattr(p, name) for p in points], dim=0))
            else:
                # Drop stale attributes that no longer match the merged points.
                setattr(self, name, None)
        self.device = first.device

        if update_batch_info:
            self.batch_size = len(points)
            # Build the counts directly as int64; going through float32
            # (`torch.Tensor`) is inexact above 2**24 points.
            self.batch_npt = torch.tensor([p.npt for p in points], dtype=torch.long)
            self.batch_id = torch.cat([p.points.new_full((p.npt, 1), i)
                                       for i, p in enumerate(points)], dim=0)
        return self

    def split_points(self):
        r''' Splits the batched points into a list of Points.

        Returns:
          List[Points]: :attr:`batch_size` point clouds with a batch size of 1;
          :obj:`batch_id` is not copied to the outputs.
        '''
        if self.batch_npt is None:
            if self.batch_id is None:
                # Without batch indices the points can only form one cloud.
                assert self.batch_size == 1, \
                    'batch_id is required to split a batch of points.'
                self.batch_npt = torch.tensor([self.npt], dtype=torch.long)
            else:
                # bincount needs a 1-D integer tensor, but batch_id is (N, 1)
                # and is floating point when produced by `merge_points`.
                self.batch_npt = torch.bincount(
                    self.batch_id.reshape(-1).long(), minlength=self.batch_size)

        outs = []
        # NOTE(review): assumes `cumsum(..., exclusive=True)` returns
        # batch_size + 1 offsets starting at 0, so cs[i]:cs[i+1] is cloud i.
        cs = cumsum(self.batch_npt, dim=0, exclusive=True)
        for i in range(self.batch_size):
            rng = range(cs[i], cs[i + 1])
            out = Points.init_points(self.device, batch_size=1)
            out.points = self.points[rng]
            if self.normals is not None:
                out.normals = self.normals[rng]
            if self.features is not None:
                out.features = self.features[rng]
            if self.labels is not None:
                out.labels = self.labels[rng]
            outs.append(out)
        return outs

    @classmethod
    def init_points(cls, device: Union[torch.device, str, None] = None,
                    batch_size: int = 1):
        r''' Initializes a Points object with dummy data on a specified device.

        The dummy data has one zero point per batch element and, for
        :obj:`batch_size > 1`, the batch indices :obj:`0..batch_size-1`.

        Args:
          device (torch.device or str or None): The device of the Points. If
              :obj:`None`, PyTorch's default device (normally the CPU) is used.
          batch_size (int): The batch size.
        '''
        points = torch.zeros(batch_size, 3, device=device)
        batch_id = (torch.arange(batch_size, device=device).unsqueeze(1)
                    if batch_size > 1 else None)
        return cls(points, batch_size=batch_size, batch_id=batch_id)
def merge_points(points: List['Points'], update_batch_info: bool = True):
    r''' Merges a list of single-sample Points into one batched Points.

    This is a thin wrapper that allocates an empty batch with
    :meth:`Points.init_points` and fills it via :meth:`Points.merge_points`.

    .. deprecated:: 2.2.7
      Use :meth:`Points.merge_points` instead.

    Args:
      points (List[Points]): The point clouds to merge; must not be empty.
      update_batch_info (bool): Forwarded to :meth:`Points.merge_points`.

    Returns:
      Points: The newly created, merged Points.
    '''
    assert len(points) > 0, 'The input points list is empty.'
    batch_size = len(points)
    merged = Points.init_points(points[0].device, batch_size=batch_size)
    merged.merge_points(points, update_batch_info)
    return merged