# Source code for ocnn.utils

# --------------------------------------------------------
# Octree-based Sparse Convolutional Neural Networks
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
# Licensed under The MIT License [see LICENSE for details]
# Written by Peng-Shuai Wang
# --------------------------------------------------------

import math
import torch
from typing import Optional
from packaging import version

import ocnn


# Public API of this module: the names exported by `from ocnn.utils import *`.
__all__ = [
    'trunc_div',
    'meshgrid',
    'cumsum',
    'scatter_add',
    'xavier_uniform_',
    'resize_with_last_val',
    'list2str',
    'build_example_octree',
]

# `classes` is the very same list object as `__all__` (an alias, not a copy).
# NOTE(review): its consumer lives outside this file, presumably the docs build.
classes = __all__


def trunc_div(input, other):
  r''' Wraps :func:`torch.div` for compatibility. It rounds the results of the
  division towards zero and is equivalent to C-style integer division.

  Args:
    input (torch.Tensor): The dividend.
    other (torch.Tensor or Number): The divisor.

  Returns:
    torch.Tensor: The quotient rounded towards zero.
  '''

  # The `rounding_mode` argument of `torch.div` exists since PyTorch 1.8.0.
  # Compare the numeric release tuple instead of the full version: PEP 440
  # orders local builds such as '1.7.1+cu110' *after* '1.7.1', so comparing
  # `> 1.7.1` would wrongly select the new code path on those builds.
  if version.parse(torch.__version__).release >= (1, 8):
    return torch.div(input, other, rounding_mode='trunc')

  # Older releases: `torch.floor_divide` is used as the truncating division
  # (its documentation for those versions states it rounds towards zero).
  return torch.floor_divide(input, other)
def meshgrid(*tensors, indexing: Optional[str] = None):
  r''' Wraps :func:`torch.meshgrid` for compatibility.

  Args:
    *tensors: Forwarded unchanged to :func:`torch.meshgrid`.
    indexing (str, optional): Forwarded to :func:`torch.meshgrid` only when
        the installed PyTorch accepts it (>= 1.10); otherwise it is ignored
        and PyTorch's default indexing of that version is used.

  Returns:
    The tuple of grid tensors produced by :func:`torch.meshgrid`.
  '''

  # `indexing` was added to `torch.meshgrid` in PyTorch 1.10.0. Compare the
  # release tuple so that local builds such as '1.9.1+cu111' (ordered after
  # '1.9.1' by PEP 440) are not passed an argument they do not support.
  if version.parse(torch.__version__).release >= (1, 10):
    return torch.meshgrid(*tensors, indexing=indexing)
  return torch.meshgrid(*tensors)
def range_grid(min: int, max: int, device: torch.device = 'cpu'):
  r''' Enumerates every integer point of the cube :obj:`[min, max]^3`, with
  :attr:`max` included, as the rows of a long tensor of shape :obj:`(N, 3)`.

  Rows come in lexicographic order (the last coordinate changes fastest),
  which is the order of a flattened :obj:`'ij'`-indexed mesh grid. When
  :attr:`max` >= :attr:`min`, :obj:`N = (max - min + 1) ** 3`.

  Args:
    min (int): Smallest coordinate value.
    max (int): Largest coordinate value (inclusive).
    device (torch.device, optional): Device the grid is allocated on.

  Returns:
    torch.Tensor: The :obj:`(N, 3)` grid of dtype :obj:`torch.long`.

  Example:
    >>> grid = range_grid(0, 1)
    >>> print(grid)
    tensor([[0, 0, 0],
            [0, 0, 1],
            [0, 1, 0],
            [0, 1, 1],
            [1, 0, 0],
            [1, 0, 1],
            [1, 1, 0],
            [1, 1, 1]])
  '''

  axis = torch.arange(min, max + 1, dtype=torch.long, device=device)
  # The cartesian product of three copies of `axis` yields (a, b, c) triples
  # with `c` varying fastest -- identical to stacking an 'ij' meshgrid.
  return torch.cartesian_prod(axis, axis, axis)
def cumsum(data: torch.Tensor, dim: int, exclusive: bool = False):
  r''' Extends :func:`torch.cumsum` with the input argument :attr:`exclusive`.

  Args:
    data (torch.Tensor): The input data.
    dim (int): The dimension to do the operation over.
    exclusive (bool): If false, the behavior is the same as
        :func:`torch.cumsum`; if true, a zero slice is prepended along
        :attr:`dim`, giving the exclusive cumulative sum. Note that if true,
        the output is larger by 1 than :attr:`data` along :attr:`dim`.
  '''

  running = torch.cumsum(data, dim)
  if not exclusive:
    return running

  # A single zero slice along `dim`, prepended to shift the sums by one.
  pad_shape = list(data.shape)
  pad_shape[dim] = 1
  return torch.cat([running.new_zeros(pad_shape), running], dim)
def broadcast(src: torch.Tensor, other: torch.Tensor, dim: int):
  r''' Broadcast :attr:`src` according to :attr:`other`, originally from the
  library `pytorch_scatter`.

  A 1-D :attr:`src` is first placed on axis :attr:`dim`; then trailing
  singleton axes are added until :attr:`src` has as many dimensions as
  :attr:`other`, and the result is expanded to the shape of :attr:`other`.
  '''

  axis = dim if dim >= 0 else other.dim() + dim

  if src.dim() == 1:
    # Leading singleton axes so that the data sits on `axis`.
    src = src[(None,) * axis]

  missing = other.dim() - src.dim()
  if missing > 0:
    # Trailing singleton axes up to the rank of `other`.
    src = src[(Ellipsis,) + (None,) * missing]

  return src.expand_as(other)
def scatter_add(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                out: Optional[torch.Tensor] = None,
                dim_size: Optional[int] = None,) -> torch.Tensor:
  r''' Reduces all values from the :attr:`src` tensor into :attr:`out` at the
  indices specified in the :attr:`index` tensor along a given axis
  :attr:`dim`. This is a wrapper of :meth:`torch.Tensor.scatter_add_` that
  broadcasts :attr:`index` to the shape of :attr:`src` first.

  Args:
    src (torch.Tensor): The source tensor.
    index (torch.Tensor): The indices of elements to scatter.
    dim (int): The axis along which to index, (default: :obj:`-1`).
    out (torch.Tensor or None): The destination tensor; it is updated in
        place and returned.
    dim_size (int or None): If :attr:`out` is not given, automatically create
        output with size :attr:`dim_size` at dimension :attr:`dim`. If
        :attr:`dim_size` is not given, a minimal sized output tensor
        according to :obj:`index.max() + 1` is returned (size 0 when
        :attr:`index` is empty).
  '''

  # Broadcast `index` to the shape of `src` (same rule as `broadcast`): a
  # 1-D index is placed on `dim`, then padded with trailing singleton axes.
  axis = dim if dim >= 0 else src.dim() + dim
  if index.dim() == 1:
    index = index[(None,) * axis]
  missing = src.dim() - index.dim()
  if missing > 0:
    index = index[(Ellipsis,) + (None,) * missing]
  index = index.expand_as(src)

  if out is None:
    out_shape = list(src.size())
    if dim_size is not None:
      out_shape[dim] = dim_size
    else:
      out_shape[dim] = 0 if index.numel() == 0 else int(index.max()) + 1
    out = torch.zeros(out_shape, dtype=src.dtype, device=src.device)

  return out.scatter_add_(dim, index, src)
def xavier_uniform_(weights: torch.Tensor, gain: float = 1.0):
  r''' Initialize convolution weights with the same method as
  :obj:`torch.nn.init.xavier_uniform_`.

  :obj:`torch.nn.init.xavier_uniform_` initializes a tensor with shape
  :obj:`(out_c, in_c, kdim)`, which cannot be used in
  :class:`ocnn.nn.OctreeConv` since the shape of :attr:`OctreeConv.weights`
  is :obj:`(kdim, in_c, out_c)`.

  Args:
    weights (torch.Tensor): The weights of shape :obj:`(kdim, in_c, out_c)`,
        filled in place with values from :math:`U(-a, a)`.
    gain (float): Optional scaling factor, as in
        :obj:`torch.nn.init.xavier_uniform_` (default: :obj:`1.0`).

  Returns:
    torch.Tensor: :attr:`weights` itself, following :mod:`torch.nn.init`.
  '''

  shape = weights.shape  # (kernel_dim, in_conv, out_conv)
  fan_in = shape[0] * shape[1]
  fan_out = shape[0] * shape[2]
  std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
  a = math.sqrt(3.0) * std  # uniform bound giving the above std
  torch.nn.init.uniform_(weights, -a, a)
  return weights
def resize_with_last_val(list_in: list, num: int = 3):
  r''' Pads :attr:`list_in` in place by repeating its last element until it
  holds :attr:`num` elements, and returns the same list object.

  :attr:`list_in` must be a plain :obj:`list` with at most :attr:`num`
  elements; this is checked with an :obj:`assert`.
  '''

  assert (type(list_in) is list and len(list_in) < num + 1)
  # In-place extension (`+=` on a list mutates it), so callers holding a
  # reference to `list_in` observe the padding.
  list_in += [list_in[-1] for _ in range(len(list_in), num)]
  return list_in
def list2str(list_in: list):
  r''' Concatenates the :obj:`str` form of every element of :attr:`list_in`,
  without any separator.
  '''

  return ''.join(map(str, list_in))
def build_example_octree(depth: int = 5, full_depth: int = 2, pt_num: int = 3):
  r''' Builds an example octree on CPU from the first :attr:`pt_num` points
  (1 to 3, checked with an :obj:`assert`) of a fixed toy point cloud.

  Args:
    depth (int): Forwarded to :class:`ocnn.octree.Octree` as its depth.
    full_depth (int): Forwarded to :class:`ocnn.octree.Octree`.
    pt_num (int): How many of the three toy points to use.
  '''

  assert pt_num <= 3 and pt_num > 0

  # One row per toy sample: (point, normal, feature, label).
  samples = [
      ([-1, -1, -1], [1, 0, 0], [1, -1], [0]),
      ([0, 0, -1], [-1, 0, 0], [2, -2], [2]),
      ([0.0625, 0.0625, -1], [0, 1, 0], [3, -3], [2]),
  ]
  points, normals, features, labels = (
      torch.Tensor(list(column)) for column in zip(*samples))

  point_cloud = ocnn.octree.Points(
      points[:pt_num], normals[:pt_num], features[:pt_num], labels[:pt_num])

  octree = ocnn.octree.Octree(depth, full_depth)
  octree.build_octree(point_cloud)
  return octree
def has_nan_inf(x: torch.Tensor, name: str):
  r''' Checks if all elements in :attr:`x` are finite for debugging. If not,
  raises a :obj:`RuntimeError` naming the tensor and reporting its dtype,
  maximum absolute value, the number of NaN and Inf entries, and its size.

  Args:
    x (torch.Tensor): The tensor to check for finiteness.
    name (str): The name of the tensor, used in the error message if
        non-finite values are found.

  Raises:
    RuntimeError: If :attr:`x` contains any NaN or Inf value.
  '''

  if not torch.isfinite(x).all():
    # When this branch runs, `max` is necessarily nan or inf, so the counts
    # are what tell how widespread the problem is.
    num_nan = int(torch.isnan(x).sum())
    num_inf = int(torch.isinf(x).sum())
    raise RuntimeError(
        f"{name} has NaN/Inf: dtype={x.dtype}, max={x.abs().max().item()}, "
        f"num_nan={num_nan}, num_inf={num_inf}, numel={x.numel()}")


def trilinear_interp_weights(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
  r''' Computes trilinear interpolation weights of N sample points with
  respect to the 8 corners of the unit cube :obj:`[0, 1]^3`. The weights are
  returned, not interpolated values; coordinates are not clamped.

  Args:
    x (torch.Tensor): (N) tensor containing x coordinates of N sample points.
    y (torch.Tensor): (N) tensor containing y coordinates of N sample points.
    z (torch.Tensor): (N) tensor containing z coordinates of N sample points.

  Returns:
    torch.Tensor: (N, 8) weights. Column :obj:`k` belongs to the corner
    :obj:`(i, j, l)` with :obj:`k = 4 * i + 2 * j + l`, i.e. z changes
    fastest, then y, then x.
  '''

  # 1. Compute base weights along each dimension.
  wx0, wx1 = 1 - x, x
  wy0, wy1 = 1 - y, y
  wz0, wz1 = 1 - z, z

  # 2. Compute combined weights for the 8 vertices and concatenate to (N, 8).
  #    Index variation order: Z changes fastest, then Y, then X: 000, 001,
  #    010, 011, 100...
  weights = torch.stack([
      wx0 * wy0 * wz0,  # (0, 0, 0)
      wx0 * wy0 * wz1,  # (0, 0, 1)
      wx0 * wy1 * wz0,  # (0, 1, 0)
      wx0 * wy1 * wz1,  # (0, 1, 1)
      wx1 * wy0 * wz0,  # (1, 0, 0)
      wx1 * wy0 * wz1,  # (1, 0, 1)
      wx1 * wy1 * wz0,  # (1, 1, 0)
      wx1 * wy1 * wz1,  # (1, 1, 1)
  ], dim=1)
  return weights