diff --git a/model_examples/PanoOcc/migrate_to_ascend/dist_train.sh b/model_examples/PanoOcc/migrate_to_ascend/dist_train.sh new file mode 100644 index 0000000000000000000000000000000000000000..35e7be6e972087e9685ef7bdb64a1e02ae314b12 --- /dev/null +++ b/model_examples/PanoOcc/migrate_to_ascend/dist_train.sh @@ -0,0 +1,22 @@ +#!/usr/bin/env bash +# +CONFIG=$1 +GPUS=$2 +WORK_DIR=$3 +NNODES=${NNODES:-1} +NODE_RANK=${NODE_RANK:-0} +PORT=${PORT:-29500} +MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} + +PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ +torchrun \ + --nnodes=$NNODES \ + --node_rank=$NODE_RANK \ + --master_addr=$MASTER_ADDR \ + --nproc_per_node=$GPUS \ + --master_port=$PORT \ + $(dirname "$0")/train.py \ + $CONFIG \ + --seed 0 \ + --work-dir ${WORK_DIR} \ + --launcher pytorch ${@:4} --deterministic \ No newline at end of file diff --git a/model_examples/PanoOcc/mmdetection3d.patch b/model_examples/PanoOcc/migrate_to_ascend/mmdetection3d.patch similarity index 100% rename from model_examples/PanoOcc/mmdetection3d.patch rename to model_examples/PanoOcc/migrate_to_ascend/mmdetection3d.patch diff --git a/model_examples/PanoOcc/migrate_to_ascend/panoseg_occ_head_patch.py b/model_examples/PanoOcc/migrate_to_ascend/panoseg_occ_head_patch.py new file mode 100644 index 0000000000000000000000000000000000000000..acb32c809d6057faabc16c42e582f3ec69edd967 --- /dev/null +++ b/model_examples/PanoOcc/migrate_to_ascend/panoseg_occ_head_patch.py @@ -0,0 +1,740 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +import importlib +from types import ModuleType +from typing import Dict, Optional, Union + +import torch +import mx_driving + +def custom_unique_n3(coors, return_inverse, return_counts, dim): + # assert dim == 0 + + voxels = mx_driving._C.point_to_voxel(coors, [], [], "ZYX") + cnt, unq_voxels, unq_ind, argsort_ind, _ = mx_driving._C.unique_voxel(voxels) + unq_coors = mx_driving._C.voxel_to_point(unq_voxels, [], [], "ZYX") + + if return_inverse: + sorted_ind = torch.argsort(argsort_ind.to(torch.float32), dim=dim).to(torch.long) + is_unq = torch.zeros(coors.size(0)).to(coors.device) + is_unq[unq_ind] = 1 + unq_inv_sorted = is_unq.cumsum(dim) - 1 + unq_inv = torch.gather(unq_inv_sorted, dim, sorted_ind) + unq_inv = unq_inv.to(torch.int64) + + if return_counts: + unq_ind_nxt = torch.ones_like(unq_ind) * coors.size(0) + unq_ind_nxt[:-1] = unq_ind[1:] + unq_cnt = unq_ind_nxt - unq_ind + unq_cnt = unq_cnt.to(torch.int64) + + if return_inverse and return_counts: + return unq_coors, unq_inv, unq_cnt + elif return_inverse: + return unq_coors, unq_inv + elif return_counts: + return unq_coors, unq_cnt + else: + return unq_coors + +def panoseg_occ_head_patch(panoseg_occ_head_module: ModuleType, options: Dict): + import copy + from mmdet.models import HEADS + from mmdet.models.dense_heads import DETRHead + from mmdet3d.core.bbox.coders import build_bbox_coder + import torch + import torch.nn as nn + import torch.nn.functional as F + from mmdet.models.builder import build_loss + from mmcv.cnn import Linear, bias_init_with_prob + from mmcv.runner import force_fp32, auto_fp16 + from mmdet.models.utils.transformer import inverse_sigmoid + from mmdet.core import (multi_apply, multi_apply, reduce_mean) + from projects.mmdet3d_plugin.core.bbox.util import normalize_bbox + from mmcv.utils import TORCH_VERSION, digit_version + + @HEADS.register_module(force=True) + class PanoSegOccHead(DETRHead): + def __init__(self, + *args, + with_box_refine=False, + 
as_two_stage=False, + transformer=None, + bbox_coder=None, + num_cls_fcs=2, + code_weights=None, + bev_h=30, + bev_w=30, + bev_z=5, + voxel_lidar=[0.05, 0.05, 0.05], + voxel_det=[2.048,2.048,1], + loss_occupancy=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=5.0), + loss_occupancy_aux=None, + loss_occupancy_det=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=5.0), + bg_weight=0.02, + **kwargs): + + self.bev_h = bev_h + self.bev_w = bev_w + self.bev_z = bev_z + self.voxel_lidar = voxel_lidar + self.voxel_det = voxel_det + self.fp16_enabled = False + self.bg_weight = bg_weight + + self.with_box_refine = with_box_refine + self.as_two_stage = as_two_stage + if self.as_two_stage: + transformer['as_two_stage'] = self.as_two_stage + if 'code_size' in kwargs: + self.code_size = kwargs['code_size'] + else: + self.code_size = 10 + if code_weights is not None: + self.code_weights = code_weights + else: + self.code_weights = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.2, 0.2] + + self.bbox_coder = build_bbox_coder(bbox_coder) + self.pc_range = self.bbox_coder.pc_range + self.real_w = self.pc_range[3] - self.pc_range[0] + self.real_h = self.pc_range[4] - self.pc_range[1] + self.num_cls_fcs = num_cls_fcs - 1 + super(PanoSegOccHead, self).__init__( + *args, transformer=transformer, **kwargs) + self.code_weights = nn.Parameter(torch.tensor( + self.code_weights, requires_grad=False), requires_grad=False) + self.lidar_seg_loss = build_loss(loss_occupancy) + self.lidar_det_loss = build_loss(loss_occupancy_det) + if loss_occupancy_aux is not None: + self.lidar_seg_aux_loss = build_loss(loss_occupancy_aux) + + self.pc_range = nn.Parameter(torch.tensor( + self.pc_range, requires_grad=False), requires_grad=False) + self.voxel_lidar = nn.Parameter(torch.tensor( + self.voxel_lidar, requires_grad=False), requires_grad=False) + self.voxel_det = nn.Parameter(torch.tensor( + self.voxel_det, requires_grad=False), requires_grad=False) + + def _init_layers(self): + """Initialize classification branch and regression branch of head.""" + cls_branch = [] + for _ in range(self.num_reg_fcs): + cls_branch.append(Linear(self.embed_dims, self.embed_dims)) + cls_branch.append(nn.LayerNorm(self.embed_dims)) + cls_branch.append(nn.ReLU(inplace=True)) + cls_branch.append(Linear(self.embed_dims, self.cls_out_channels)) + fc_cls = nn.Sequential(*cls_branch) + + reg_branch = [] + for _ in range(self.num_reg_fcs): + reg_branch.append(Linear(self.embed_dims, self.embed_dims)) + reg_branch.append(nn.ReLU()) + reg_branch.append(Linear(self.embed_dims, self.code_size)) + reg_branch = nn.Sequential(*reg_branch) + + def _get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + # last reg_branch is used to generate proposal from + # encode feature map when as_two_stage is True. 
+ num_pred = (self.transformer.decoder.num_layers + 1) if \ + self.as_two_stage else self.transformer.decoder.num_layers + + if self.with_box_refine: + self.cls_branches = _get_clones(fc_cls, num_pred) + self.reg_branches = _get_clones(reg_branch, num_pred) + else: + self.cls_branches = nn.ModuleList( + [fc_cls for _ in range(num_pred)]) + self.reg_branches = nn.ModuleList( + [reg_branch for _ in range(num_pred)]) + + if not self.as_two_stage: + self.bev_embedding = nn.Embedding( + self.bev_h * self.bev_w * self.bev_z, self.embed_dims) + self.query_embedding = nn.Embedding(self.num_query, + self.embed_dims * 2) + + def init_weights(self): + """Initialize weights of the DeformDETR head.""" + self.transformer.init_weights() + if self.loss_cls.use_sigmoid: + bias_init = bias_init_with_prob(0.01) + for m in self.cls_branches: + nn.init.constant_(m[-1].bias, bias_init) + + @auto_fp16(apply_to=('mlvl_feats')) + def forward(self, mlvl_feats, img_metas, prev_bev=None, only_bev=False): + """Forward function. + Args: + mlvl_feats (tuple[Tensor]): Features from the upstream + network, each is a 5D-tensor with shape + (B, N, C, H, W). + prev_bev: previous bev featues + only_bev: only compute BEV features with encoder. + Returns: + all_cls_scores (Tensor): Outputs from the classification head, \ + shape [nb_dec, bs, num_query, cls_out_channels]. Note \ + cls_out_channels should includes background. + all_bbox_preds (Tensor): Sigmoid outputs from the regression \ + head with normalized coordinate format (cx, cy, w, l, cz, h, theta, vx, vy). \ + Shape [nb_dec, bs, num_query, 9]. + """ + + bs, num_cam, _, _, _ = mlvl_feats[0].shape + dtype = mlvl_feats[0].dtype + object_query_embeds = self.query_embedding.weight.to(dtype) + bev_queries = self.bev_embedding.weight.to(dtype) + + bev_mask = torch.zeros((bs, self.bev_h, self.bev_w, self.bev_z), device=bev_queries.device, dtype=dtype) + bev_pos = self.positional_encoding(bev_mask).to(dtype) + + if only_bev: + outputs = self.transformer( + mlvl_feats, + bev_queries, + object_query_embeds, + self.bev_h, + self.bev_w, + self.bev_z, + grid_length=(self.real_h / self.bev_h, + self.real_w / self.bev_w), + bev_pos=bev_pos, + reg_branches=self.reg_branches if self.with_box_refine else None, # noqa:E501 + cls_branches=self.cls_branches if self.as_two_stage else None, + img_metas=img_metas, + prev_bev=prev_bev, + ) + bev_feat, bev_embed, hs, init_reference, inter_references, occupancy, occupancy_det = outputs + return bev_feat, bev_embed + + outputs = self.transformer( + mlvl_feats, + bev_queries, + object_query_embeds, + self.bev_h, + self.bev_w, + self.bev_z, + grid_length=(self.real_h / self.bev_h, + self.real_w / self.bev_w), + bev_pos=bev_pos, + reg_branches=self.reg_branches if self.with_box_refine else None, # noqa:E501 + cls_branches=self.cls_branches if self.as_two_stage else None, + img_metas=img_metas, + prev_bev=prev_bev + ) + + bev_feat, bev_embed, hs, init_reference, inter_references, occupancy, occupancy_det = outputs + hs = hs.permute(0, 2, 1, 3) + outputs_classes = [] + outputs_coords = [] + for lvl in range(hs.shape[0]): + if lvl == 0: + reference = init_reference + else: + reference = inter_references[lvl - 1] + reference = inverse_sigmoid(reference) + outputs_class = self.cls_branches[lvl](hs[lvl]) + tmp = self.reg_branches[lvl](hs[lvl]) + + # TODO: check the shape of reference + # assert reference.shape[-1] == 3 + tmp[..., 0:2] += reference[..., 0:2] + tmp[..., 0:2] = tmp[..., 0:2].sigmoid() + tmp[..., 4:5] += reference[..., 2:3] + tmp[..., 4:5] = 
tmp[..., 4:5].sigmoid() + tmp[..., 0:1] = (tmp[..., 0:1] * (self.pc_range[3] - + self.pc_range[0]) + self.pc_range[0]) + tmp[..., 1:2] = (tmp[..., 1:2] * (self.pc_range[4] - + self.pc_range[1]) + self.pc_range[1]) + tmp[..., 4:5] = (tmp[..., 4:5] * (self.pc_range[5] - + self.pc_range[2]) + self.pc_range[2]) + + # TODO: check if using sigmoid + outputs_coord = tmp + outputs_classes.append(outputs_class) + outputs_coords.append(outputs_coord) + + outputs_classes = torch.stack(outputs_classes) + outputs_coords = torch.stack(outputs_coords) + + outs = { + 'bev_feat': bev_feat, + 'bev_embed': bev_embed, + 'all_cls_scores': outputs_classes, + 'all_bbox_preds': outputs_coords, + 'enc_cls_scores': None, + 'enc_bbox_preds': None, + 'occupancy': occupancy, + 'occupancy_det':occupancy_det, + } + + return outs + + def _get_target_single(self, + cls_score, + bbox_pred, + gt_labels, + gt_bboxes, + gt_bboxes_ignore=None): + """"Compute regression and classification targets for one image. + Outputs from a single decoder layer of a single feature level are used. + Args: + cls_score (Tensor): Box score logits from a single decoder layer + for one image. Shape [num_query, cls_out_channels]. + bbox_pred (Tensor): Sigmoid outputs from a single decoder layer + for one image, with normalized coordinate (cx, cy, w, h) and + shape [num_query, 4]. + gt_bboxes (Tensor): Ground truth bboxes for one image with + shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format. + gt_labels (Tensor): Ground truth class indices for one image + with shape (num_gts, ). + gt_bboxes_ignore (Tensor, optional): Bounding boxes + which can be ignored. Default None. + Returns: + tuple[Tensor]: a tuple containing the following for one image. + - labels (Tensor): Labels of each image. + - label_weights (Tensor]): Label weights of each image. + - bbox_targets (Tensor): BBox targets of each image. + - bbox_weights (Tensor): BBox weights of each image. + - pos_inds (Tensor): Sampled positive indices for each image. + - neg_inds (Tensor): Sampled negative indices for each image. + """ + + num_bboxes = bbox_pred.size(0) + # assigner and sampler + gt_c = gt_bboxes.shape[-1] + + assign_result = self.assigner.assign(bbox_pred, cls_score, gt_bboxes, + gt_labels) + + sampling_result = self.sampler.sample(assign_result, bbox_pred, + gt_bboxes) + pos_inds = sampling_result.pos_inds + neg_inds = sampling_result.neg_inds + + # label targets + labels = gt_bboxes.new_full((num_bboxes,), + self.num_classes, + dtype=torch.long) + labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds] + label_weights = gt_bboxes.new_ones(num_bboxes) + + # bbox targets + bbox_targets = torch.zeros_like(bbox_pred)[..., :gt_c] + bbox_weights = torch.zeros_like(bbox_pred) + bbox_weights[pos_inds] = 1.0 + + # DETR + bbox_targets[pos_inds] = sampling_result.pos_gt_bboxes + return (labels, label_weights, bbox_targets, bbox_weights, + pos_inds, neg_inds) + + def get_targets(self, + cls_scores_list, + bbox_preds_list, + gt_bboxes_list, + gt_labels_list, + gt_bboxes_ignore_list=None): + """"Compute regression and classification targets for a batch image. + Outputs from a single decoder layer of a single feature level are used. + Args: + cls_scores_list (list[Tensor]): Box score logits from a single + decoder layer for each image with shape [num_query, + cls_out_channels]. + bbox_preds_list (list[Tensor]): Sigmoid outputs from a single + decoder layer for each image, with normalized coordinate + (cx, cy, w, h) and shape [num_query, 4]. 
+ gt_bboxes_list (list[Tensor]): Ground truth bboxes for each image + with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format. + gt_labels_list (list[Tensor]): Ground truth class indices for each + image with shape (num_gts, ). + gt_bboxes_ignore_list (list[Tensor], optional): Bounding + boxes which can be ignored for each image. Default None. + Returns: + tuple: a tuple containing the following targets. + - labels_list (list[Tensor]): Labels for all images. + - label_weights_list (list[Tensor]): Label weights for all \ + images. + - bbox_targets_list (list[Tensor]): BBox targets for all \ + images. + - bbox_weights_list (list[Tensor]): BBox weights for all \ + images. + - num_total_pos (int): Number of positive samples in all \ + images. + - num_total_neg (int): Number of negative samples in all \ + images. + """ + + (labels_list, label_weights_list, bbox_targets_list, + bbox_weights_list, pos_inds_list, neg_inds_list) = multi_apply( + self._get_target_single, cls_scores_list, bbox_preds_list, + gt_labels_list, gt_bboxes_list) + num_total_pos = sum((inds.numel() for inds in pos_inds_list)) + num_total_neg = sum((inds.numel() for inds in neg_inds_list)) + return (labels_list, label_weights_list, bbox_targets_list, + bbox_weights_list, num_total_pos, num_total_neg) + + def loss_single(self, + cls_scores, + bbox_preds, + gt_bboxes_list, + gt_labels_list, + gt_bboxes_ignore_list=None): + """"Loss function for outputs from a single decoder layer of a single + feature level. + Args: + cls_scores (Tensor): Box score logits from a single decoder layer + for all images. Shape [bs, num_query, cls_out_channels]. + bbox_preds (Tensor): Sigmoid outputs from a single decoder layer + for all images, with normalized coordinate (cx, cy, w, h) and + shape [bs, num_query, 4]. + gt_bboxes_list (list[Tensor]): Ground truth bboxes for each image + with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format. + gt_labels_list (list[Tensor]): Ground truth class indices for each + image with shape (num_gts, ). + gt_bboxes_ignore_list (list[Tensor], optional): Bounding + boxes which can be ignored for each image. Default None. + Returns: + dict[str, Tensor]: A dictionary of loss components for outputs from + a single decoder layer. 
+ """ + num_imgs = cls_scores.size(0) + cls_scores_list = [cls_scores[i] for i in range(num_imgs)] + bbox_preds_list = [bbox_preds[i] for i in range(num_imgs)] + cls_reg_targets = self.get_targets(cls_scores_list, bbox_preds_list, + gt_bboxes_list, gt_labels_list) + (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list, + num_total_pos, num_total_neg) = cls_reg_targets + labels = torch.cat(labels_list, 0) + label_weights = torch.cat(label_weights_list, 0) + bbox_targets = torch.cat(bbox_targets_list, 0) + bbox_weights = torch.cat(bbox_weights_list, 0) + + # classification loss + cls_scores = cls_scores.reshape(-1, self.cls_out_channels) + # construct weighted avg_factor to match with the official DETR repo + cls_avg_factor = num_total_pos * 1.0 + num_total_neg * self.bg_cls_weight + if self.sync_cls_avg_factor: + cls_avg_factor = reduce_mean(cls_scores.new_tensor([cls_avg_factor])) + + cls_avg_factor = max(cls_avg_factor, 1) + + loss_cls = self.loss_cls( + cls_scores, labels, label_weights, avg_factor=cls_avg_factor) + + # Compute the average number of gt boxes accross all gpus, for + # normalization purposes + num_total_pos = loss_cls.new_tensor([num_total_pos]) + num_total_pos = torch.clamp(reduce_mean(num_total_pos), min=1).item() + + # regression L1 loss + bbox_preds = bbox_preds.reshape(-1, bbox_preds.size(-1)) + normalized_bbox_targets = normalize_bbox(bbox_targets, self.pc_range) + isnotnan = torch.isfinite(normalized_bbox_targets).all(dim=-1) + bbox_weights = bbox_weights * self.code_weights + + loss_bbox = self.loss_bbox( + bbox_preds[isnotnan, :10], normalized_bbox_targets[isnotnan, :10], bbox_weights[isnotnan, :10], + avg_factor=num_total_pos) + if digit_version(TORCH_VERSION) >= digit_version('1.8'): + loss_cls = torch.nan_to_num(loss_cls) + loss_bbox = torch.nan_to_num(loss_bbox) + return loss_cls, loss_bbox + + def get_occupancy_det_label(self,voxel_coors_det, voxel_label_det, occupancy_det_label): + + voxel_coors_det[:, 0] = voxel_coors_det[:, 0].clip(min=0, max=self.bev_z-1) + voxel_coors_det[:, 1] = voxel_coors_det[:, 1].clip(min=0, max=self.bev_h-1) + voxel_coors_det[:, 2] = voxel_coors_det[:, 2].clip(min=0, max=self.bev_w-1) + + det_label_binary = ((voxel_label_det>=1) & (voxel_label_det<=10)) + det_label = det_label_binary.long() + occupancy_det_label[0, voxel_coors_det[:, 0], voxel_coors_det[:, 1], voxel_coors_det[:, 2]] = det_label + return occupancy_det_label + + def get_det_loss(self,voxel_label_det,occupancy_det_label,occupancy_det_pred): + + num_total_pos_det = len(voxel_label_det) + + num_total_neg_det = len(occupancy_det_label) - num_total_pos_det + avg_factor_det = num_total_pos_det * 1.0 + num_total_neg_det * self.bg_weight + if self.sync_cls_avg_factor: + avg_factor_det = reduce_mean( + occupancy_det_pred.new_tensor([avg_factor_det])) + avg_factor_det = max(avg_factor_det, 1) + + losses_det = self.lidar_det_loss(occupancy_det_pred, occupancy_det_label, avg_factor=avg_factor_det) + return losses_det + + @force_fp32(apply_to=('preds_dicts')) + def loss(self, + gt_bboxes_list, + gt_labels_list, + pts_sem, + preds_dicts, + dense_occupancy=None, + gt_bboxes_ignore=None, + img_metas=None): + """"Loss function. + Args: + + gt_bboxes_list (list[Tensor]): Ground truth bboxes for each image + with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format. + gt_labels_list (list[Tensor]): Ground truth class indices for each + image with shape (num_gts, ). 
+ preds_dicts: + all_cls_scores (Tensor): Classification score of all + decoder layers, has shape + [nb_dec, bs, num_query, cls_out_channels]. + all_bbox_preds (Tensor): Sigmoid regression + outputs of all decode layers. Each is a 4D-tensor with + normalized coordinate format (cx, cy, w, h) and shape + [nb_dec, bs, num_query, 4]. + enc_cls_scores (Tensor): Classification scores of + points on encode feature map , has shape + (N, h*w, num_classes). Only be passed when as_two_stage is + True, otherwise is None. + enc_bbox_preds (Tensor): Regression results of each points + on the encode feature map, has shape (N, h*w, 4). Only be + passed when as_two_stage is True, otherwise is None. + gt_bboxes_ignore (list[Tensor], optional): Bounding boxes + which can be ignored for each image. Default None. + Returns: + dict[str, Tensor]: A dictionary of loss components. + """ + + # Extract the first three columns from pts_sem + pts = pts_sem[:, :3] + + # Extract the fourth column from pts_sem + pts_semantic_mask = pts_sem[:, 3:4] + + # If dense_occupancy is None, perform voxelization and label voxelization + if dense_occupancy is None: + _, voxel_coors, voxel_label = self.voxelize(pts, self.pc_range, self.voxel_lidar, pts_semantic_mask) + + # Perform voxelization and label voxelization for detection + _, voxel_coors_det, voxel_label_det = self.voxelize(pts, self.pc_range, self.voxel_det, pts_semantic_mask) + + all_cls_scores = preds_dicts['all_cls_scores'] + all_bbox_preds = preds_dicts['all_bbox_preds'] + enc_cls_scores = preds_dicts['enc_cls_scores'] + enc_bbox_preds = preds_dicts['enc_bbox_preds'] + occupancy = preds_dicts['occupancy'] + occupancy_det = preds_dicts['occupancy_det'] + + occupancy_pred = occupancy.squeeze(0) + occupancy_det_pred = occupancy_det[0].squeeze(0) + + cls_num, occ_z, occ_h, occ_w = occupancy_pred.shape + if dense_occupancy is None: + occupancy_label = torch.full((1, occ_z, occ_h, occ_w), cls_num, device=occupancy_pred.device, dtype=torch.long) + else: + occupancy_label = (torch.zeros(1, occ_z, occ_h, occ_w)).to(occupancy_pred.device).long() + + occupancy_det_label = (torch.ones(1, self.bev_z, self.bev_h, self.bev_w) * 2).to(occupancy_det_pred.device).long() + + if dense_occupancy is None: + voxel_coors[:, 0] = voxel_coors[:, 0].clip(min=0, max=occ_z-1) + voxel_coors[:, 1] = voxel_coors[:, 1].clip(min=0, max=occ_h-1) + voxel_coors[:, 2] = voxel_coors[:, 2].clip(min=0, max=occ_w-1) + occupancy_label[0, voxel_coors[:, 0], voxel_coors[:, 1], voxel_coors[:, 2]] = voxel_label + else: + dense_occupancy = dense_occupancy.long().squeeze(0) + occupancy_label[0, dense_occupancy[:, 0], dense_occupancy[:, 1], dense_occupancy[:, 2]] = dense_occupancy[:, 3] + + occupancy_det_label = self.get_occupancy_det_label(voxel_coors_det, voxel_label_det, occupancy_det_label) + + losses_seg_aux = self.lidar_seg_aux_loss(occupancy_pred.unsqueeze(0), occupancy_label) + + occupancy_det_label = occupancy_det_label.reshape(-1) + occupancy_label = occupancy_label.reshape(-1) + + # assert occupancy_label.max()<=cls_num and occupancy_label.min()>=0 + occupancy_pred = occupancy_pred.reshape(cls_num,-1).permute(1,0) + occupancy_det_pred = occupancy_det_pred.reshape(2,-1).permute(1,0) + + num_dec_layers = len(all_cls_scores) + device = gt_labels_list[0].device + + gt_bboxes_list = [torch.cat((gt_bboxes.gravity_center, gt_bboxes.tensor[:, 3:]), + dim=1).to(device) for gt_bboxes in gt_bboxes_list] + + all_gt_bboxes_list = [gt_bboxes_list for _ in range(num_dec_layers)] + all_gt_labels_list = [gt_labels_list for _ 
in range(num_dec_layers)] + + losses_cls, losses_bbox = multi_apply( + self.loss_single, all_cls_scores, all_bbox_preds, + all_gt_bboxes_list, all_gt_labels_list) + + loss_dict = dict() + + # Lidar seg loss + if dense_occupancy is None: + num_total_pos = len(voxel_label) + else: + num_total_pos = len(dense_occupancy) + num_total_neg = len(occupancy_label) - num_total_pos + avg_factor = num_total_pos * 1.0 + num_total_neg * self.bg_weight + if self.sync_cls_avg_factor: + avg_factor = reduce_mean( + occupancy_pred.new_tensor([avg_factor])) + avg_factor = max(avg_factor, 1) + + losses_seg = self.lidar_seg_loss(occupancy_pred, occupancy_label, avg_factor=avg_factor) + + # Lidar det loss + losses_det = self.get_det_loss(voxel_label_det, occupancy_det_label, occupancy_det_pred) + + # loss of proposal generated from encode feature map. + if enc_cls_scores is not None: + binary_labels_list = [ + torch.zeros_like(gt_labels_list[i]) + for i in range(len(all_gt_labels_list)) + ] + enc_loss_cls, enc_losses_bbox = \ + self.loss_single(enc_cls_scores, enc_bbox_preds, + gt_bboxes_list, binary_labels_list) + loss_dict['enc_loss_cls'] = enc_loss_cls + loss_dict['enc_loss_bbox'] = enc_losses_bbox + + # loss from the last decoder layer + loss_dict['loss_cls'] = losses_cls[-1] + loss_dict['loss_bbox'] = losses_bbox[-1] + loss_dict['loss_seg'] = losses_seg + loss_dict['loss_det'] = losses_det + loss_dict['loss_seg_aux'] = losses_seg_aux + + # loss from other decoder layers + num_dec_layer = 0 + for loss_cls_i, loss_bbox_i in zip(losses_cls[:-1], losses_bbox[:-1]): + loss_dict[f'd{num_dec_layer}.loss_cls'] = loss_cls_i + loss_dict[f'd{num_dec_layer}.loss_bbox'] = loss_bbox_i + num_dec_layer += 1 + + return loss_dict + + def voxelize(self, points, pc_range, voxel_size, pts_semantic_mask=None): + """ + Input: + points [N, 3], (x, y, z) + point_cloud_range [6], [-51.2, -51.2, -5.0, 51.2, 51.2, 3.0], (-x, -y, -z, x, y, z) + voxelization_size [3], e.g. [0.256, 0.256, 0.125] + + Output: + coors [N,4], (0, z, y, x) + unq_coors [M,4], (0, z, y, x) + + """ + + coors = torch.div(points[:, :3] - pc_range[None, :3], voxel_size[None, :], rounding_mode='floor').to(torch.int32) + + unq_coors, unq_inv = custom_unique_n3(coors, return_inverse=True, return_counts=False, dim=0) + + if pts_semantic_mask is not None: + with torch.no_grad(): + voxel_label_my, _ = mx_driving.scatter_max(pts_semantic_mask, unq_inv.to(torch.int32)) + return coors[:, [2, 1, 0]].long(), unq_coors.long(), voxel_label_my.squeeze(-1).long() + return coors[:, [2, 1, 0]].long(), unq_coors.long() + + @force_fp32(apply_to=('preds_dicts')) + def get_bboxes(self, preds_dicts, img_metas, rescale=False): + """Generate bboxes from bbox head predictions. + Args: + preds_dicts (tuple[list[dict]]): Prediction results. + img_metas (list[dict]): Point cloud and image's meta info. + Returns: + list[dict]: Decoded bbox, scores and labels after nms. 
+ """ + + preds_dicts = self.bbox_coder.decode(preds_dicts) + + num_samples = len(preds_dicts) + ret_list = [] + for i in range(num_samples): + preds = preds_dicts[i] + bboxes = preds['bboxes'] + + bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 5] * 0.5 + + code_size = bboxes.shape[-1] + bboxes = img_metas[i]['box_type_3d'](bboxes, code_size) + scores = preds['scores'] + labels = preds['labels'] + + ret_list.append([bboxes, scores, labels]) + + return ret_list + + def decode_lidar_seg(self, points, occupancy): + + pts_coors, _ = self.voxelize(points, self.pc_range, self.voxel_lidar) + + # clip out-ranged points + z_max = int((self.pc_range[5] - self.pc_range[2]) / self.voxel_lidar[2]) - 1 + y_max = int((self.pc_range[4] - self.pc_range[1]) / self.voxel_lidar[1]) - 1 + x_max = int((self.pc_range[3] - self.pc_range[0]) / self.voxel_lidar[0]) - 1 + + pts_coors[:, 0] = pts_coors[:, 0].clip(min=0, max=z_max) + pts_coors[:, 1] = pts_coors[:, 1].clip(min=0, max=y_max) + pts_coors[:, 2] = pts_coors[:, 2].clip(min=0, max=x_max) + + pts_pred = occupancy[:, :, pts_coors[:, 0], pts_coors[:, 1], pts_coors[:, 2]].squeeze(0).softmax(dim=0).argmax(dim=0).cpu().numpy() + + return pts_pred + + def decode_lidar_seg_hr(self,points,occupancy): + + out_h = 512 + out_w = 512 + out_z = 160 + + self.voxel_lidar = [102.4/out_h, 102.4/out_w, 8/out_z] + + pts_coors, _ = self.voxelize(points, self.pc_range, self.voxel_lidar) + + # clip out-ranged points + z_max = int((self.pc_range[5] - self.pc_range[2]) / self.voxel_lidar[2]) - 1 + y_max = int((self.pc_range[4] - self.pc_range[1]) / self.voxel_lidar[1]) - 1 + x_max = int((self.pc_range[3] - self.pc_range[0]) / self.voxel_lidar[0]) - 1 + pts_coors[:, 0] = pts_coors[:, 0].clip(min=0, max=z_max) + pts_coors[:, 1] = pts_coors[:, 1].clip(min=0, max=y_max) + pts_coors[:, 2] = pts_coors[:, 2].clip(min=0, max=x_max) + + + new_h = torch.linspace(-1, 1, out_h).view(1, out_h, 1).expand(out_z, out_h, out_w) + new_w = torch.linspace(-1, 1, out_w).view(1, 1, out_w).expand(out_z, out_h, out_w) + new_z = torch.linspace(-1, 1, out_z).view(out_z, 1, 1).expand(out_z, out_h, out_w) + + grid = torch.cat((new_w.unsqueeze(3), new_h.unsqueeze(3), new_z.unsqueeze(3)), dim=-1) + + grid = grid.unsqueeze(0).to(occupancy.device) + + torch.npu.set_compile_mode(jit_compile=True) + out_logit = F.grid_sample(occupancy, grid=grid) + torch.npu.set_compile_mode(jit_compile=False) + + pts_pred = out_logit[:, :, pts_coors[:, 0], pts_coors[:, 1], pts_coors[:, 2]].squeeze(0).softmax(dim=0).argmax(dim=0).cpu().numpy() + return pts_pred + + def decode_lidar_seg_dense(self, dense, occupancy): + dense = dense.long() + pts_pred = occupancy[:, :, dense[0, :, 0], dense[0, :, 1], dense[0, :, 2]].squeeze(0).softmax(dim=0).argmax(dim=0).cpu().numpy() + return pts_pred + + if hasattr(panoseg_occ_head_module, 'PanoSegOccHead'): + panoseg_occ_head_module.PanoSegOccHead = PanoSegOccHead + else: + raise ValueError('PanoSegOccHead attr not found') \ No newline at end of file diff --git a/model_examples/PanoOcc/migrate_to_ascend/patch.py b/model_examples/PanoOcc/migrate_to_ascend/patch.py new file mode 100644 index 0000000000000000000000000000000000000000..9b0299bf02f8ca4292e96c264e94c511456bc58d --- /dev/null +++ b/model_examples/PanoOcc/migrate_to_ascend/patch.py @@ -0,0 +1,856 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# Copyright (c) OpenMMLab. All rights reserved. 
+import importlib +import collections +import sys +import os +import types +import random +from types import ModuleType +from typing import Dict + +import numpy as np +import torch +import torch.nn.functional as F +import torch_npu +import mmcv +import mmcv.runner + +import mx_driving +from mx_driving.patcher import PatcherBuilder, Patch +from mx_driving.patcher.distribute import ddp, ddp_forward +from mx_driving.patcher.functions import stream +from mx_driving.patcher.tensor import index, batch_matmul +from mx_driving.patcher.numpy import numpy_type +from mx_driving.patcher.mmcv import mdc, msda, dc +from mx_driving.patcher.optimizer import optimizer_hooks +from mx_driving.patcher.mmdet import resnet_add_relu, resnet_maxpool +from mx_driving.patcher.mmdet3d import nuscenes_dataset +from mx_driving import multi_scale_deformable_attn + +from migrate_to_ascend.panoseg_occ_head_patch import panoseg_occ_head_patch + + +def occ_temporal_attention_patch(occ_temporal_attention_module: ModuleType, options: Dict): + + def forward(self, query, key=None, value=None, identity=None, query_pos=None, key_padding_mask=None, + reference_points=None, spatial_shapes=None, level_start_index=None, flag='decoder', **kwargs): + if value is None: + assert self.batch_first + bs, len_bev, c = query.shape + value = torch.stack([query, query], 1).reshape(bs*2, len_bev, c) + + # value = torch.cat([query, query], 0) + + if identity is None: + identity = query + if query_pos is not None: + query = query + query_pos + if not self.batch_first: + # change to (bs, num_query ,embed_dims) + query = query.permute(1, 0, 2) + value = value.permute(1, 0, 2) + bs, num_query, embed_dims = query.shape + _, num_value, _ = value.shape + + assert (spatial_shapes[:, 0] * spatial_shapes[:, 1] * spatial_shapes[:, 2]).sum() == num_value + assert self.num_bev_queue == 2 + + query = torch.cat([value[:bs], query], -1) + value = self.value_proj(value) + + if key_padding_mask is not None: + value = value.masked_fill(key_padding_mask[..., None], 0.0) + + value = value.reshape(bs*self.num_bev_queue, + num_value, self.num_heads, -1) + + sampling_offsets = self.sampling_offsets(query) + sampling_offsets = sampling_offsets.view( + bs, num_query, self.num_heads, self.num_bev_queue, self.num_levels, self.num_points, 2) + attention_weights = self.attention_weights(query).view( + bs, num_query, self.num_heads, self.num_bev_queue, self.num_levels * self.num_points) + attention_weights = attention_weights.softmax(-1) + + attention_weights = attention_weights.view(bs, num_query, + self.num_heads, + self.num_bev_queue, + self.num_levels, + self.num_points) + + attention_weights = attention_weights.permute(0, 3, 1, 2, 4, 5)\ + .reshape(bs*self.num_bev_queue, num_query, self.num_heads, self.num_levels, self.num_points).contiguous() + sampling_offsets = sampling_offsets.permute(0, 3, 1, 2, 4, 5, 6)\ + .reshape(bs*self.num_bev_queue, num_query, self.num_heads, self.num_levels, self.num_points, 2) + + # all points in pillar have the same xy + z_num = sampling_offsets.shape[1] //reference_points.shape[1] + bsq,bev_num,level,xy = reference_points.shape + reference_points = reference_points.unsqueeze(2).expand(bsq,bev_num,z_num,level,xy).reshape(bsq,-1,level,xy) + + if reference_points.shape[-1] == 2: + offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1) + sampling_locations = reference_points[:, :, None, :, None, :] \ + + sampling_offsets \ + / offset_normalizer[None, None, None, :, None, :] + elif reference_points.shape[-1] == 4: 
+ sampling_locations = reference_points[:, :, None, :, None, :2] \ + + sampling_offsets / self.num_points \ + * reference_points[:, :, None, :, None, 2:] \ + * 0.5 + else: + raise ValueError( + f'Last dim of reference_points must be' + f' 2 or 4, but get {reference_points.shape[-1]} instead.') + + sampling_locations = sampling_locations.contiguous() + if torch.cuda.is_available() and value.is_cuda: + output = multi_scale_deformable_attn(value, spatial_shapes, level_start_index, + sampling_locations, attention_weights) + else: + raise ValueError("CUDA/CANN unavailable?") + + # output shape (bs*num_bev_queue, num_query, embed_dims) + # (bs*num_bev_queue, num_query, embed_dims)-> (num_query, embed_dims, bs*num_bev_queue) + output = output.permute(1, 2, 0) + + # fuse history value and current value + # (num_query, embed_dims, bs*num_bev_queue)-> (num_query, embed_dims, bs, num_bev_queue) + output = output.view(num_query, embed_dims, bs, self.num_bev_queue) + output = output.mean(-1) + + # (num_query, embed_dims, bs)-> (bs, num_query, embed_dims) + output = output.permute(2, 0, 1) + + output = self.output_proj(output) + + if not self.batch_first: + output = output.permute(1, 0, 2) + + return self.dropout(output) + identity + + if hasattr(occ_temporal_attention_module, 'OccTemporalAttention'): + occ_temporal_attention_module.OccTemporalAttention.forward = forward + else: + raise ValueError('OccTemporalAttention attr not found') + + +def panoseg_transformer_occ_patch(panoseg_transformer_occ_module: ModuleType, options: Dict): + import numpy as np + import torch.nn as nn + + def align_prev_bev(self, prev_bev, bev_h, bev_w, bev_z, **kwargs): + if prev_bev is not None: + pc_range = self.cam_encoder.pc_range + ref_y, ref_x, ref_z = torch.meshgrid( + torch.linspace(0.5, bev_h - 0.5, bev_h, dtype=prev_bev.dtype, device=prev_bev.device), + torch.linspace(0.5, bev_w - 0.5, bev_w, dtype=prev_bev.dtype, device=prev_bev.device), + torch.linspace(0.5, bev_z - 0.5, bev_z, dtype=prev_bev.dtype, device=prev_bev.device), + ) + ref_y = ref_y / bev_h + ref_x = ref_x / bev_w + ref_z = ref_z / bev_z + + grid = torch.stack( + (ref_x, + ref_y, + ref_z, + ref_x.new_ones(ref_x.shape)), dim=-1) + + min_x, min_y, min_z, max_x, max_y, max_z = pc_range + grid[..., 0] = grid[..., 0] * (max_x - min_x) + min_x + grid[..., 1] = grid[..., 1] * (max_y - min_y) + min_y + grid[..., 2] = grid[..., 2] * (max_z - min_z) + min_z + grid = grid.reshape(-1, 4) + + bs = prev_bev.shape[0] + len_queue = prev_bev.shape[1] + assert bs == 1 + for i in range(bs): + assert len_queue + 1 == len(kwargs["img_metas"][i]["ego2global_transform_lst"]) + lidar_to_ego = kwargs['img_metas'][i]['lidar2ego_transformation'] + curr_ego_to_global = kwargs['img_metas'][i]['ego2global_transform_lst'][-1] + + curr_grid_in_prev_frame_lst = [] + for j in range(len_queue): + prev_ego_to_global = kwargs['img_metas'][i]['ego2global_transform_lst'][j] + prev_lidar_to_curr_lidar = np.linalg.inv(lidar_to_ego) @ np.linalg.inv(curr_ego_to_global) @ prev_ego_to_global @ lidar_to_ego + curr_lidar_to_prev_lidar = np.linalg.inv(prev_lidar_to_curr_lidar) + curr_lidar_to_prev_lidar = grid.new_tensor(curr_lidar_to_prev_lidar) + + # fix z + curr_lidar_to_prev_lidar[2,3] = curr_lidar_to_prev_lidar[2,3]*0 + + curr_grid_in_prev_frame = torch.matmul(curr_lidar_to_prev_lidar, grid.T).T.reshape(bev_h, bev_w, bev_z, -1)[..., :3] + curr_grid_in_prev_frame[..., 0] = (curr_grid_in_prev_frame[..., 0] - min_x) / (max_x - min_x) + curr_grid_in_prev_frame[..., 1] = (curr_grid_in_prev_frame[..., 
1] - min_y) / (max_y - min_y) + curr_grid_in_prev_frame[..., 2] = (curr_grid_in_prev_frame[..., 2] - min_z) / (max_z - min_z) + curr_grid_in_prev_frame = curr_grid_in_prev_frame * 2.0 - 1.0 + curr_grid_in_prev_frame_lst.append(curr_grid_in_prev_frame) + + curr_grid_in_prev_frame = torch.stack(curr_grid_in_prev_frame_lst, dim=0) + + + torch.npu.set_compile_mode(jit_compile=True) # +++ + prev_bev_warp_to_curr_frame = nn.functional.grid_sample( + prev_bev[i].permute(0, 1, 4, 2, 3), # [bs, dim, z, h, w] + curr_grid_in_prev_frame.permute(0, 3, 1, 2, 4), # [bs, z, h, w, 3] + align_corners=False) + torch.npu.set_compile_mode(jit_compile=False) # +++ + + prev_bev = prev_bev_warp_to_curr_frame.permute(0, 1, 3, 4, 2).unsqueeze(0) # add bs dim, [bs, dim, h, w, z] + + return prev_bev + + + if hasattr(panoseg_transformer_occ_module, 'PanoSegOccTransformer'): + panoseg_transformer_occ_module.PanoSegOccTransformer.align_prev_bev = align_prev_bev + else: + raise ValueError('PanoSegOccTransformer attr not found') + + +def temporal_self_attention_patch(temporal_self_attention_module: ModuleType, options: Dict): + def forward(self, query, key=None, value=None, identity=None, query_pos=None, key_padding_mask=None, + reference_points=None, spatial_shapes=None, level_start_index=None, flag='decoder', **kwargs): + if value is None: + assert self.batch_first + bs, len_bev, c = query.shape + value = torch.stack([query, query], 1).reshape(bs*2, len_bev, c) + + # value = torch.cat([query, query], 0) + + if identity is None: + identity = query + if query_pos is not None: + query = query + query_pos + if not self.batch_first: + # change to (bs, num_query ,embed_dims) + query = query.permute(1, 0, 2) + value = value.permute(1, 0, 2) + bs, num_query, embed_dims = query.shape + _, num_value, _ = value.shape + assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value + assert self.num_bev_queue == 2 + + query = torch.cat([value[:bs], query], -1) + value = self.value_proj(value) + + if key_padding_mask is not None: + value = value.masked_fill(key_padding_mask[..., None], 0.0) + + value = value.reshape(bs*self.num_bev_queue, + num_value, self.num_heads, -1) + + sampling_offsets = self.sampling_offsets(query) + sampling_offsets = sampling_offsets.view( + bs, num_query, self.num_heads, self.num_bev_queue, self.num_levels, self.num_points, 2) + attention_weights = self.attention_weights(query).view( + bs, num_query, self.num_heads, self.num_bev_queue, self.num_levels * self.num_points) + attention_weights = attention_weights.softmax(-1) + + attention_weights = attention_weights.view(bs, num_query, + self.num_heads, + self.num_bev_queue, + self.num_levels, + self.num_points) + + attention_weights = attention_weights.permute(0, 3, 1, 2, 4, 5)\ + .reshape(bs*self.num_bev_queue, num_query, self.num_heads, self.num_levels, self.num_points).contiguous() + sampling_offsets = sampling_offsets.permute(0, 3, 1, 2, 4, 5, 6)\ + .reshape(bs*self.num_bev_queue, num_query, self.num_heads, self.num_levels, self.num_points, 2) + + if reference_points.shape[-1] == 2: + offset_normalizer = torch.stack( + [spatial_shapes[..., 1], spatial_shapes[..., 0]], -1) + sampling_locations = reference_points[:, :, None, :, None, :] \ + + sampling_offsets \ + / offset_normalizer[None, None, None, :, None, :] + + elif reference_points.shape[-1] == 4: + sampling_locations = reference_points[:, :, None, :, None, :2] \ + + sampling_offsets / self.num_points \ + * reference_points[:, :, None, :, None, 2:] \ + * 0.5 + else: + raise ValueError( + 
f'Last dim of reference_points must be' + f' 2 or 4, but get {reference_points.shape[-1]} instead.') + if torch.cuda.is_available() and value.is_cuda: + output = multi_scale_deformable_attn(value, spatial_shapes, level_start_index, + sampling_locations, attention_weights) + else: + raise ValueError("CUDA/CANN unavailable?") + + # output shape (bs*num_bev_queue, num_query, embed_dims) + # (bs*num_bev_queue, num_query, embed_dims)-> (num_query, embed_dims, bs*num_bev_queue) + output = output.permute(1, 2, 0) + + # fuse history value and current value + # (num_query, embed_dims, bs*num_bev_queue)-> (num_query, embed_dims, bs, num_bev_queue) + output = output.view(num_query, embed_dims, bs, self.num_bev_queue) + output = output.mean(-1) + + # (num_query, embed_dims, bs)-> (bs, num_query, embed_dims) + output = output.permute(2, 0, 1) + + output = self.output_proj(output) + + if not self.batch_first: + output = output.permute(1, 0, 2) + + return self.dropout(output) + identity + + if hasattr(temporal_self_attention_module, 'TemporalSelfAttention'): + temporal_self_attention_module.TemporalSelfAttention.forward = forward + else: + raise ValueError('TemporalSelfAttention attr not found') + + +def worker_init_fn(worker_id, num_workers, rank, seed): + # The seed of each worker equals to + # num_worker * rank + worker_id + user_seed + worker_seed = num_workers * rank + worker_id + seed + np.random.seed(worker_seed) + random.seed(worker_seed) + + +def mmdet3d_dataset_builder_patch(builder_module: ModuleType, options: Dict): + from functools import partial + from mmcv.runner import get_dist_info + from mmcv.parallel import collate + from mmdet.datasets.samplers import GroupSampler + from projects.mmdet3d_plugin.datasets.samplers.sampler import build_sampler + from torch.utils.data import DataLoader + + + # Pin Memory + def build_dataloader(dataset, + samples_per_gpu, + workers_per_gpu, + num_gpus=1, + dist=True, + shuffle=True, + seed=None, + shuffler_sampler=None, + nonshuffler_sampler=None, + **kwargs): + """Build PyTorch DataLoader. + In distributed training, each GPU/process has a dataloader. + In non-distributed training, there is only one dataloader for all GPUs. + Args: + dataset (Dataset): A PyTorch dataset. + samples_per_gpu (int): Number of training samples on each GPU, i.e., + batch size of each GPU. + workers_per_gpu (int): How many subprocesses to use for data loading + for each GPU. + num_gpus (int): Number of GPUs. Only used in non-distributed training. + dist (bool): Distributed training/test or not. Default: True. + shuffle (bool): Whether to shuffle the data at every epoch. + Default: True. + kwargs: any keyword argument to be used to initialize DataLoader + Returns: + DataLoader: A PyTorch dataloader. 
+ """ + rank, world_size = get_dist_info() + if dist: + # DistributedGroupSampler will definitely shuffle the data to satisfy + # that images on each GPU are in the same group + if shuffle: + sampler = build_sampler(shuffler_sampler if shuffler_sampler is not None else dict(type='DistributedGroupSampler'), + dict( + dataset=dataset, + samples_per_gpu=samples_per_gpu, + num_replicas=world_size, + rank=rank, + seed=seed) + ) + + else: + sampler = build_sampler(nonshuffler_sampler if nonshuffler_sampler is not None else dict(type='DistributedSampler'), + dict( + dataset=dataset, + num_replicas=world_size, + rank=rank, + shuffle=shuffle, + seed=seed) + ) + + batch_size = samples_per_gpu + num_workers = workers_per_gpu + else: + # assert False, 'not support in bevformer' + print('WARNING!!!!, Only can be used for obtain inference speed!!!!') + sampler = GroupSampler(dataset, samples_per_gpu) if shuffle else None + batch_size = num_gpus * samples_per_gpu + num_workers = num_gpus * workers_per_gpu + + init_fn = partial( + worker_init_fn, num_workers=num_workers, rank=rank, + seed=seed) if seed is not None else None + + data_loader = DataLoader( + dataset, + batch_size=batch_size, + sampler=sampler, + num_workers=num_workers, + collate_fn=partial(collate, samples_per_gpu=samples_per_gpu), + pin_memory=True, + worker_init_fn=init_fn, + **kwargs) + + return data_loader + + if hasattr(builder_module, 'build_dataloader'): + builder_module.build_dataloader = build_dataloader + else: + raise ValueError('build_dataloader attr not found') + + +def mmdet3d_dataset_compose_patch(compose_module: ModuleType, options: Dict): + + from mmcv.utils import build_from_cfg + from mmdet.datasets.builder import PIPELINES + from mmdet3d.datasets.builder import PIPELINES as PIPELINES_3d + + def __init__(self, transforms): + assert isinstance(transforms, collections.abc.Sequence) + self.transforms = [] + for transform in transforms: + if isinstance(transform, dict): + if transform["type"] not in PIPELINES: + transform = build_from_cfg(transform, PIPELINES_3d) + else: + transform = build_from_cfg(transform, PIPELINES) + self.transforms.append(transform) + elif callable(transform): + self.transforms.append(transform) + else: + raise TypeError('transform must be callable or a dict') + + if hasattr(compose_module, 'CustomCompose'): + compose_module.CustomCompose.__init__ = __init__ + else: + raise ValueError('CustomCompose attr not found') + + +def decoder_patch(decoder_module: ModuleType, options: Dict): + + from mx_driving import multi_scale_deformable_attn + + def forward(self, + query, + key=None, + value=None, + identity=None, + query_pos=None, + key_padding_mask=None, + reference_points=None, + spatial_shapes=None, + level_start_index=None, + flag='decoder', + **kwargs): + + if value is None: + value = query + + if identity is None: + identity = query + if query_pos is not None: + query = query + query_pos + if not self.batch_first: + # change to (bs, num_query ,embed_dims) + query = query.permute(1, 0, 2) + value = value.permute(1, 0, 2) + + bs, num_query, _ = query.shape + bs, num_value, _ = value.shape + assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value + + value = self.value_proj(value) + if key_padding_mask is not None: + value = value.masked_fill(key_padding_mask[..., None], 0.0) + value = value.view(bs, num_value, self.num_heads, -1) + + sampling_offsets = self.sampling_offsets(query).view( + bs, num_query, self.num_heads, self.num_levels, self.num_points, 2) + attention_weights = 
self.attention_weights(query).view( + bs, num_query, self.num_heads, self.num_levels * self.num_points) + attention_weights = attention_weights.softmax(-1) + + attention_weights = attention_weights.view(bs, num_query, + self.num_heads, + self.num_levels, + self.num_points) + if reference_points.shape[-1] == 2: + offset_normalizer = torch.stack( + [spatial_shapes[..., 1], spatial_shapes[..., 0]], -1) + sampling_locations = reference_points[:, :, None, :, None, :] \ + + sampling_offsets \ + / offset_normalizer[None, None, None, :, None, :] + elif reference_points.shape[-1] == 4: + sampling_locations = reference_points[:, :, None, :, None, :2] \ + + sampling_offsets / self.num_points \ + * reference_points[:, :, None, :, None, 2:] \ + * 0.5 + else: + raise ValueError( + f'Last dim of reference_points must be' + f' 2 or 4, but get {reference_points.shape[-1]} instead.') + if torch.cuda.is_available() and value.is_cuda: + output = multi_scale_deformable_attn(value, spatial_shapes, level_start_index, + sampling_locations, attention_weights) + else: + output = multi_scale_deformable_attn_pytorch( + value, spatial_shapes, sampling_locations, attention_weights) + + output = self.output_proj(output) + + if not self.batch_first: + # (num_query, bs ,embed_dims) + output = output.permute(1, 0, 2) + + return self.dropout(output) + identity + + if hasattr(decoder_module, 'CustomMSDeformableAttention'): + decoder_module.CustomMSDeformableAttention.forward = forward + else: + raise ValueError('CustomMSDeformableAttention attr not found') + + +def spatial_cross_attention_patch(spatial_cross_attention_module: ModuleType, options: Dict): + + import math + import warnings + from mx_driving import multi_scale_deformable_attn + from mmcv.runner import force_fp32 + + indexes_global = options['indexes_global'] + max_len_global = options['max_len_global'] + bev_mask_id_global = options['bev_mask_id_global'] + count_global = options['count_global'] + + @force_fp32(apply_to=('query', 'key', 'value', 'query_pos', 'reference_points_cam')) + def sca_forward(self, + query, + key, + value, + residual=None, + query_pos=None, + key_padding_mask=None, + reference_points=None, + spatial_shapes=None, + reference_points_cam=None, + bev_mask=None, + level_start_index=None, + flag='encoder', + **kwargs): + if key is None: + key = query + if value is None: + value = key + + if residual is None: + inp_residual = query + slots = torch.zeros_like(query) + if query_pos is not None: + query = query + query_pos + + bs, num_query, _ = query.size() + # bevformer reference_points_cam shape: (num_cam,bs,h*w,num_points_in_pillar,2) + D = reference_points_cam.size(3) + indexes = [] + global indexes_global, max_len_global, bev_mask_id_global, count_global + bev_mask_id = id(bev_mask) + if bev_mask_id == bev_mask_id_global: + indexes = indexes_global + max_len = max_len_global + count = count_global + else: + count = torch.any(bev_mask, 3) + bev_mask_ = count.squeeze() + for i, mask_per_img in enumerate(bev_mask_): + index_query_per_img = mask_per_img.nonzero().squeeze(-1) + indexes.append(index_query_per_img) + + max_len = max([len(each) for each in indexes]) + count = count.permute(1, 2, 0).sum(-1) + count = torch.clamp(count, min=1.0) + count = count[..., None] + count_global = count + indexes_global = indexes + max_len_global = max_len + bev_mask_id_global = bev_mask_id + + # each camera only interacts with its corresponding BEV queries. This step can greatly save GPU memory. 
+ queries_rebatch = query.new_zeros( + [bs, self.num_cams, max_len, self.embed_dims]) + reference_points_rebatch = reference_points_cam.new_zeros( + [bs, self.num_cams, max_len, D, 2]) + + for i, reference_points_per_img in enumerate(reference_points_cam): + index_query_per_img = indexes[i] + for j in range(bs): + queries_rebatch[j, i, :len(index_query_per_img)] = query[j, index_query_per_img] + reference_points_rebatch[j, i, :len(index_query_per_img)] = reference_points_per_img[j, index_query_per_img] + + num_cams, l, bs, embed_dims = key.shape + + key = key.permute(2, 0, 1, 3).reshape( + bs * self.num_cams, l, self.embed_dims) + value = value.permute(2, 0, 1, 3).reshape( + bs * self.num_cams, l, self.embed_dims) + + queries = self.deformable_attention(query=queries_rebatch.view(bs * self.num_cams, max_len, self.embed_dims), key=key, value=value, + reference_points=reference_points_rebatch.view(bs * self.num_cams, max_len, D, 2), spatial_shapes=spatial_shapes, + level_start_index=level_start_index).view(bs, self.num_cams, max_len, self.embed_dims) + for j in range(bs): + for i, index_query_per_img in enumerate(indexes): + slots[j, index_query_per_img] += queries[j, i, :len(index_query_per_img)] + + + slots = slots / count + slots = self.output_proj(slots) + + return self.dropout(slots) + inp_residual + + + def msda3d_forward(self, + query, + key=None, + value=None, + identity=None, + query_pos=None, + key_padding_mask=None, + reference_points=None, + spatial_shapes=None, + level_start_index=None, + **kwargs): + if value is None: + value = query + if identity is None: + identity = query + if query_pos is not None: + query = query + query_pos + + if not self.batch_first: + # change to (bs, num_query ,embed_dims) + query = query.permute(1, 0, 2) + value = value.permute(1, 0, 2) + + bs, num_query, _ = query.shape + bs, num_value, _ = value.shape + # assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value + + value = self.value_proj(value) + if key_padding_mask is not None: + value = value.masked_fill(key_padding_mask[..., None], 0.0) + value = value.view(bs, num_value, self.num_heads, -1) + sampling_offsets = self.sampling_offsets(query).view( + bs, num_query, self.num_heads, self.num_levels, self.num_points, 2) + attention_weights = self.attention_weights(query).view( + bs, num_query, self.num_heads, self.num_levels * self.num_points) + + attention_weights = attention_weights.softmax(-1) + + attention_weights = attention_weights.view(bs, num_query, + self.num_heads, + self.num_levels, + self.num_points) + + if reference_points.shape[-1] == 2: + """ + For each BEV query, it owns `num_Z_anchors` in 3D space that having different heights. + After proejcting, each BEV query has `num_Z_anchors` reference points in each 2D image. + For each referent point, we sample `num_points` sampling points. + For `num_Z_anchors` reference points, it has overall `num_points * num_Z_anchors` sampling points. 
+ """ + offset_normalizer = torch.stack( + [spatial_shapes[..., 1], spatial_shapes[..., 0]], -1) + + bs, num_query, num_Z_anchors, xy = reference_points.shape + reference_points = reference_points[:, :, None, None, None, :, :] + sampling_offsets = sampling_offsets / \ + offset_normalizer[None, None, None, :, None, :] + bs, num_query, num_heads, num_levels, num_all_points, xy = sampling_offsets.shape + sampling_offsets = sampling_offsets.view( + bs, num_query, num_heads, num_levels, num_all_points // num_Z_anchors, num_Z_anchors, xy) + sampling_locations = reference_points + sampling_offsets + bs, num_query, num_heads, num_levels, num_points, num_Z_anchors, xy = sampling_locations.shape + # assert num_all_points == num_points * num_Z_anchors + + sampling_locations = sampling_locations.view( + bs, num_query, num_heads, num_levels, num_all_points, xy) + + elif reference_points.shape[-1] == 4: + assert False + else: + raise ValueError( + f'Last dim of reference_points must be' + f' 2 or 4, but get {reference_points.shape[-1]} instead.') + + if torch.cuda.is_available() and value.is_cuda: + output = multi_scale_deformable_attn(value, spatial_shapes, level_start_index, + sampling_locations, attention_weights) + else: + output = multi_scale_deformable_attn_pytorch( + value, spatial_shapes, sampling_locations, attention_weights) + if not self.batch_first: + output = output.permute(1, 0, 2) + + return output + + + if hasattr(spatial_cross_attention_module, 'SpatialCrossAttention'): + spatial_cross_attention_module.SpatialCrossAttention.forward = sca_forward + else: + raise ValueError('SpatialCrossAttention attr not found') + + if hasattr(spatial_cross_attention_module, 'MSDeformableAttention3D'): + spatial_cross_attention_module.MSDeformableAttention3D.forward = msda3d_forward + else: + raise ValueError('MSDeformableAttention3D attr not found') + + +def generate_patcher_builder(): + patcher_builder = ( + PatcherBuilder() + .add_module_patch("torch", Patch(index), Patch(batch_matmul)) + .add_module_patch("numpy", Patch(numpy_type)) + .add_module_patch('mmcv.parallel', Patch(ddp), Patch(stream), Patch(ddp_forward)) + .add_module_patch('mmcv.ops', Patch(mdc), Patch(msda), Patch(dc)) + .add_module_patch('mmcv.runner.hooks', Patch(optimizer_hooks)) + .add_module_patch("mmdet.models.backbones.resnet", Patch(resnet_add_relu), Patch(resnet_maxpool)) + .add_module_patch('projects.mmdet3d_plugin.bevformer.dense_heads.panoseg_occ_head', + Patch(panoseg_occ_head_patch)) + .add_module_patch('projects.mmdet3d_plugin.bevformer.modules.decoder', Patch(decoder_patch)) + .add_module_patch('projects.mmdet3d_plugin.bevformer.modules.occ_temporal_attention', + Patch(occ_temporal_attention_patch)) + .add_module_patch('projects.mmdet3d_plugin.bevformer.modules.panoseg_transformer_occ', + Patch(panoseg_transformer_occ_patch)) + .add_module_patch('projects.mmdet3d_plugin.bevformer.modules.temporal_self_attention', + Patch(temporal_self_attention_patch)) + .add_module_patch('projects.mmdet3d_plugin.bevformer.modules.spatial_cross_attention', + Patch(spatial_cross_attention_patch, { + 'indexes_global': 'None', + 'max_len_global': 'None', + 'bev_mask_id_global': '-1', + 'count_global': 'None' + })) + .add_module_patch('projects.mmdet3d_plugin.datasets.builder', Patch(mmdet3d_dataset_builder_patch)) + .add_module_patch('projects.mmdet3d_plugin.datasets.pipelines.compose', Patch(mmdet3d_dataset_compose_patch)) + ) + return patcher_builder + + +brake_flag = False +profile_flag = False + + +def set_brake_at_step(patcher_builder: 
PatcherBuilder, end_step: int = 1000): + if profile_flag == True: + raise RuntimeError('with_profiling has been set, brake and profiling are mutually exclusive') + patcher_builder.brake_at(end_step) + brake_flag = True + + +def set_profiling(patcher_builder: PatcherBuilder, profiling_path: str, profiling_level: int = 0): + if brake_flag == True: + raise RuntimeError('brake_at has been set, brake and profiling are mutually exclusive') + patcher_builder.with_profiling(profiling_path, profiling_level) + profile_flag = True + + +class MethodPatcher: + + @staticmethod + def nccl_to_hccl(runner: ModuleType): + module = importlib.import_module(runner) + + if hasattr(module, "dist_utils"): + mp = module.dist_utils.mp + _init_dist_pytorch = module.dist_utils._init_dist_pytorch + _init_dist_mpi = module.dist_utils._init_dist_mpi + _init_dist_slurm = module.dist_utils._init_dist_slurm + + def hccl_init_dist(launcher: str, backend: str = 'nccl', **kwargs) -> None: + + # Replacement for using hccl as the backend + backend = 'hccl' + + if mp.get_start_method(allow_none=True) is None: + mp.set_start_method('spawn') + if launcher == 'pytorch': + _init_dist_pytorch(backend, **kwargs) + elif launcher == 'mpi': + _init_dist_mpi(backend, **kwargs) + elif launcher == 'slurm': + _init_dist_slurm(backend, **kwargs) + else: + raise ValueError(f'Invalid launcher type: {launcher}') + else: + raise ValueError('dist_utils attr not found') + + module.init_dist = hccl_init_dist + print("[Patch]: nccl replaced by hccl, runner.init_dist = ", module.init_dist.__name__) + + +class ConfigPatcher: + @staticmethod + def adamw_to_npu_fused_adam(runner: ModuleType): + module = importlib.import_module(runner) + copy = module.optimizer.builder.copy + build_optimizer_constructor = module.optimizer.builder.build_optimizer_constructor + + def build_optimizer(model, cfg: Dict): + optimizer_cfg = copy.deepcopy(cfg) + # change code + # use NpuFused optimizer + optimizer_cfg['type'] = 'NpuFused' + optimizer_cfg['type'] + constructor_type = optimizer_cfg.pop('constructor', + 'DefaultOptimizerConstructor') + paramwise_cfg = optimizer_cfg.pop('paramwise_cfg', None) + optim_constructor = build_optimizer_constructor( + dict( + type=constructor_type, + optimizer_cfg=optimizer_cfg, + paramwise_cfg=paramwise_cfg)) + optimizer = optim_constructor(model) + return optimizer + + module.build_optimizer = build_optimizer + + +def _init(): + # block dependencies that are not used nor installed + sys.modules['mmdet3d.ops.scatter_v2'] = ModuleType('mmdet3d.ops.scatter_v2') + sys.modules['torch_scatter'] = ModuleType('torch_scatter') + + sys.modules['projects.mmdet3d_plugin.models.backbones.sam_modeling.image_encoder'] = \ + ModuleType('projects.mmdet3d_plugin.models.backbones.sam_modeling.image_encoder') + sys.modules['projects.mmdet3d_plugin.models.backbones.sam_modeling.image_encoder.ImageEncoderViT'] = \ + ModuleType('projects.mmdet3d_plugin.models.backbones.sam_modeling.image_encoder.ImageEncoderViT') + + sys.modules['projects.mmdet3d_plugin.models.backbones.internv2_impl16'] = \ + ModuleType('projects.mmdet3d_plugin.models.backbones.internv2_impl16') + sys.modules['projects.mmdet3d_plugin.models.backbones.internv2_impl16.InternV2Impl16'] = \ + ModuleType('projects.mmdet3d_plugin.models.backbones.internv2_impl16.InternV2Impl16') + + sys.modules['spconv'] = ModuleType('spconv') + sys.modules['spconv.pytorch'] = ModuleType('spconv.pytorch') + sys.modules['spconv.pytorch.SparseConvTensor'] = ModuleType('spconv.pytorch.SparseConvTensor') + 
sys.modules['spconv.pytorch.SparseSequential'] = ModuleType('spconv.pytorch.SparseSequential') + + sys.modules['ipdb'] = ModuleType('ipdb') + sys.modules['ipdb.set_trace'] = ModuleType('ipdb.set_trace') + + torch.npu.config.allow_internal_format = False + + MethodPatcher.nccl_to_hccl('mmcv.runner') + ConfigPatcher.adamw_to_npu_fused_adam('mmcv.runner') + + +''' +Initialize to execute method patcher +Takes place before their corresponding imports +''' +_init() \ No newline at end of file diff --git a/model_examples/PanoOcc/migrate_to_ascend/requirements.txt b/model_examples/PanoOcc/migrate_to_ascend/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..a6148da007f94dd683c14cbb3dab55d93a2a220a --- /dev/null +++ b/model_examples/PanoOcc/migrate_to_ascend/requirements.txt @@ -0,0 +1,27 @@ +mmdet==2.24.0 + +mmsegmentation==0.30.0 +torchvision +ipython==8.12 + + +setuptools==56.1.0 +pyyaml +ninja + +tornado +scipy +psutil +ml-dtypes +cloudpickle +attrs +psutil +decorator +absl-py + +numpy==1.24.4 +torch==2.1.0 + +sympy +synr +protobuf \ No newline at end of file diff --git a/model_examples/PanoOcc/migrate_to_ascend/train.py b/model_examples/PanoOcc/migrate_to_ascend/train.py new file mode 100644 index 0000000000000000000000000000000000000000..ceef97b12a01de9f8ee6de311cd636c78d81e162 --- /dev/null +++ b/model_examples/PanoOcc/migrate_to_ascend/train.py @@ -0,0 +1,269 @@ +from __future__ import division + +import argparse +import copy +import mmcv +import os +import time +import torch +import warnings +from os import path as osp + +from mmcv import Config, DictAction +from mmcv.runner import get_dist_info, init_dist +from mmcv.utils import TORCH_VERSION, digit_version + +from mx_driving.patcher.mmcv import patch_mmcv_version +patch_mmcv_version('1.6.0') + +from mmdet import __version__ as mmdet_version +from mmdet3d import __version__ as mmdet3d_version + +from mmdet3d.datasets import build_dataset +from mmdet3d.models import build_model +from mmdet3d.utils import collect_env, get_root_logger +from mmdet.apis import set_random_seed +from mmseg import __version__ as mmseg_version + +import torch_npu +from torch_npu.contrib import transfer_to_npu + +from migrate_to_ascend.patch import generate_patcher_builder, set_brake_at_step, set_profiling +import mx_driving + + +def parse_args(): + parser = argparse.ArgumentParser(description='Train a detector') + parser.add_argument('config', help='train config file path') + parser.add_argument('--work-dir', help='the dir to save logs and models') + parser.add_argument( + '--resume-from', help='the checkpoint file to resume from') + parser.add_argument( + '--no-validate', + action='store_true', + help='whether not to evaluate the checkpoint during training') + group_gpus = parser.add_mutually_exclusive_group() + group_gpus.add_argument( + '--gpus', + type=int, + help='number of gpus to use ' + '(only applicable to non-distributed training)') + group_gpus.add_argument( + '--gpu-ids', + type=int, + nargs='+', + help='ids of gpus to use ' + '(only applicable to non-distributed training)') + parser.add_argument('--seed', type=int, default=0, help='random seed') + parser.add_argument( + '--deterministic', + action='store_true', + help='whether to set deterministic options for CUDNN backend.') + parser.add_argument( + '--options', + nargs='+', + action=DictAction, + help='override some settings in the used config, the key-value pair ' + 'in xxx=yyy format will be merged into config file (deprecate), ' + 'change to --cfg-options 
instead.') + parser.add_argument( + '--cfg-options', + nargs='+', + action=DictAction, + help='override some settings in the used config, the key-value pair ' + 'in xxx=yyy format will be merged into config file. If the value to ' + 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' + 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' + 'Note that the quotation marks are necessary and that no white space ' + 'is allowed.') + parser.add_argument( + '--launcher', + choices=['none', 'pytorch', 'slurm', 'mpi'], + default='none', + help='job launcher') + parser.add_argument('--local_rank', type=int, default=0) + parser.add_argument( + '--autoscale-lr', + action='store_true', + help='automatically scale lr with the number of gpus') + args = parser.parse_args() + if 'LOCAL_RANK' not in os.environ: + os.environ['LOCAL_RANK'] = str(args.local_rank) + + if args.options and args.cfg_options: + raise ValueError( + '--options and --cfg-options cannot be both specified, ' + '--options is deprecated in favor of --cfg-options') + if args.options: + warnings.warn('--options is deprecated in favor of --cfg-options') + args.cfg_options = args.options + + return args + + +def main(): + args = parse_args() + + cfg = Config.fromfile(args.config) + if args.cfg_options is not None: + cfg.merge_from_dict(args.cfg_options) + # import modules from string list. + if cfg.get('custom_imports', None): + from mmcv.utils import import_modules_from_strings + import_modules_from_strings(**cfg['custom_imports']) + + # import modules from plguin/xx, registry will be updated + if hasattr(cfg, 'plugin'): + if cfg.plugin: + import importlib + if hasattr(cfg, 'plugin_dir'): + plugin_dir = cfg.plugin_dir + _module_dir = os.path.dirname(plugin_dir) + _module_dir = _module_dir.split('/') + _module_path = _module_dir[0] + + for m in _module_dir[1:]: + _module_path = _module_path + '.' + m + print(_module_path) + plg_lib = importlib.import_module(_module_path) + else: + # import dir is the dirpath for the config file + _module_dir = os.path.dirname(args.config) + _module_dir = _module_dir.split('/') + _module_path = _module_dir[0] + for m in _module_dir[1:]: + _module_path = _module_path + '.' + m + print(_module_path) + plg_lib = importlib.import_module(_module_path) + + from projects.mmdet3d_plugin.bevformer.apis.train import custom_train_model + # set cudnn_benchmark + if cfg.get('cudnn_benchmark', False): + torch.backends.cudnn.benchmark = True + + # work_dir is determined in this priority: CLI > segment in file > filename + if args.work_dir is not None: + # update configs according to CLI args if args.work_dir is not None + cfg.work_dir = args.work_dir + elif cfg.get('work_dir', None) is None: + # use config filename as default work_dir if cfg.work_dir is None + cfg.work_dir = osp.join('./work_dirs', + osp.splitext(osp.basename(args.config))[0]) + # if args.resume_from is not None: + if args.resume_from is not None and osp.isfile(args.resume_from): + cfg.resume_from = args.resume_from + if args.gpu_ids is not None: + cfg.gpu_ids = args.gpu_ids + else: + cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus) + if digit_version(TORCH_VERSION) == digit_version('1.8.1') and cfg.optimizer['type'] == 'AdamW': + cfg.optimizer['type'] = 'AdamW2' # fix bug in Adamw + if args.autoscale_lr: + # apply the linear scaling rule + cfg.optimizer['lr'] = cfg.optimizer['lr'] * len(cfg.gpu_ids) / 8 + + # init distributed env first, since logger depends on the dist info. 
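+    # NOTE (added comment): on Ascend the distributed backend is expected to be
+    # 'hccl' rather than 'nccl'. MethodPatcher.nccl_to_hccl in
+    # migrate_to_ascend/patch.py rewrites mmcv.runner.init_dist to force that
+    # backend; the config's dist_params should be kept consistent with it,
+    # e.g. dist_params = dict(backend='hccl').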
+ if args.launcher == 'none': + distributed = False + else: + distributed = True + init_dist(args.launcher, **cfg.dist_params) + # re-set gpu_ids with distributed training mode + _, world_size = get_dist_info() + cfg.gpu_ids = range(world_size) + + # create work_dir + mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir)) + # dump config + cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config))) + # init the logger before other steps + timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) + log_file = osp.join(cfg.work_dir, f'{timestamp}.log') + # specify logger name, if we still use 'mmdet', the output info will be + # filtered and won't be saved in the log_file + # TODO: ugly workaround to judge whether we are training det or seg model + if cfg.model.type in ['EncoderDecoder3D']: + logger_name = 'mmseg' + else: + logger_name = 'mmdet' + logger = get_root_logger( + log_file=log_file, log_level=cfg.log_level, name=logger_name) + + # init the meta dict to record some important information such as + # environment info and seed, which will be logged + meta = dict() + # log env info + env_info_dict = collect_env() + env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()]) + dash_line = '-' * 60 + '\n' + logger.info('Environment info:\n' + dash_line + env_info + '\n' + + dash_line) + meta['env_info'] = env_info + meta['config'] = cfg.pretty_text + + # log some basic info + logger.info(f'Distributed training: {distributed}') + logger.info(f'Config:\n{cfg.pretty_text}') + + # set random seeds + if args.seed is not None: + logger.info(f'Set random seed to {args.seed}, ' + f'deterministic: {args.deterministic}') + set_random_seed(args.seed, deterministic=args.deterministic) + cfg.seed = args.seed + meta['seed'] = args.seed + meta['exp_name'] = osp.basename(args.config) + + model = build_model( + cfg.model, + train_cfg=cfg.get('train_cfg'), + test_cfg=cfg.get('test_cfg')) + model.init_weights() + + logger.info(f'Model:\n{model}') + datasets = [build_dataset(cfg.data.train)] + if len(cfg.workflow) == 2: + val_dataset = copy.deepcopy(cfg.data.val) + # in case we use a dataset wrapper + if 'dataset' in cfg.data.train: + val_dataset.pipeline = cfg.data.train.dataset.pipeline + else: + val_dataset.pipeline = cfg.data.train.pipeline + # set test_mode=False here in deep copied config + # which do not affect AP/AR calculation later + # noqa + val_dataset.test_mode = False + datasets.append(build_dataset(val_dataset)) + if cfg.checkpoint_config is not None: + # save mmdet version, config file content and class names in + # checkpoints as meta data + cfg.checkpoint_config.meta = dict( + mmdet_version=mmdet_version, + mmseg_version=mmseg_version, + mmdet3d_version=mmdet3d_version, + config=cfg.pretty_text, + CLASSES=datasets[0].CLASSES, + PALETTE=datasets[0].PALETTE # for segmentors + if hasattr(datasets[0], 'PALETTE') else None) + # add an attribute for visualization convenience + model.CLASSES = datasets[0].CLASSES + custom_train_model( + model, + datasets, + cfg, + distributed=distributed, + validate=(not args.no_validate), + timestamp=timestamp, + meta=meta) + + +if __name__ == '__main__': + + patcher_builder = generate_patcher_builder() + ''' + E.g. 
for setting brake step or collect profiling: + set_brake_at_step(patcher_builder=patcher_builder, end_step=10) + set_profiling(patcher_builder=patcher_builder, profiling_path='./profiling_level0', profiling_level=0) + ''' + with patcher_builder.build(): + main() diff --git a/model_examples/PanoOcc/test/train_8p_panoocc_base_4f_fp32.sh b/model_examples/PanoOcc/migrate_to_ascend/train_8p_panoocc_base_4f_fp32.sh similarity index 88% rename from model_examples/PanoOcc/test/train_8p_panoocc_base_4f_fp32.sh rename to model_examples/PanoOcc/migrate_to_ascend/train_8p_panoocc_base_4f_fp32.sh index 2eac2ca7db635b18cbcc3720ecfa96f572af3401..396b2ec5440f85b751ffbc102cd09c2a0141c765 100644 --- a/model_examples/PanoOcc/test/train_8p_panoocc_base_4f_fp32.sh +++ b/model_examples/PanoOcc/migrate_to_ascend/train_8p_panoocc_base_4f_fp32.sh @@ -4,11 +4,16 @@ NETWORK="PanoOcc_Base_4f" DEVICE_TYPE=$(uname -m) -WORLD_SIZE=8 + +WORLD_SIZE=2 # Number of NPUs/GPUs +#export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +export ASCEND_RT_VISIBLE_DEVICES=6,7 + + BATCH_SIZE=1 TOTAL_EPOCHS=24 -# 获取传入的参数,重新赋值 TOTAL_EPOCHS +# 获取传入的参数,重新赋值 TOTAL_EPOCHS, 例:--epochs=3 for para in $* do if [[ $para == --epochs=* ]]; then @@ -21,7 +26,7 @@ CASE_NAME=${NETWORK}_${WORLD_SIZE}p_bs${BATCH_SIZE}_e${TOTAL_EPOCHS} echo "[PanoOcc] CASE_NAME = ${CASE_NAME}" # 创建输出目录 -OUTPUT_PATH=./test/output/${CASE_NAME} +OUTPUT_PATH=./output/${CASE_NAME} if [ -d ${OUTPUT_PATH} ]; then rm -rf ${OUTPUT_PATH} @@ -33,18 +38,6 @@ echo "[PanoOcc] OUTPUT_PATH = ${OUTPUT_PATH}" # 配置环境变量 -# 设置 device 侧日志登记为 error -msnpureport -g error -d 0 -msnpureport -g error -d 1 -msnpureport -g error -d 2 -msnpureport -g error -d 3 -msnpureport -g error -d 4 -msnpureport -g error -d 5 -msnpureport -g error -d 6 -msnpureport -g error -d 7 -# 关闭 Device 侧 Event 日志 -msnpureport -e disable - # 将 Host 日志输出到串口, 0-关闭/1-开启 export ASCEND_SLOG_PRINT_TO_STDOUT=0 # 设置默认日志级别, 0-debug/1-info/2-warning/3-error @@ -67,6 +60,8 @@ export CPU_AFFINITY_CONF=1 # 设置是否开启 combined 标志, 0-关闭/1-开启 export COMBINED_ENABLE=1 +#减少显存占用 +export PYTORCH_NPU_ALLOC_CONF="expandable_segments:True" # 修改配置文件中的 total_epochs sed -i "s|total_epochs = .*|total_epochs = ${TOTAL_EPOCHS}|g" ./projects/configs/PanoOcc/Panoptic/PanoOcc_base_4f.py @@ -76,7 +71,8 @@ start_time=$(date +%s) # 开始训练 echo "[PanoOcc] Training..." -bash ./tools/dist_train.sh ./projects/configs/PanoOcc/Panoptic/PanoOcc_base_4f.py ${WORLD_SIZE} ${OUTPUT_PATH}/work_dir > ${OUTPUT_PATH}/train.log 2>&1 & +#bash ./migrate_to_ascend/dist_train.sh ./projects/configs/PanoOcc/Panoptic/PanoOcc_base_4f.py ${WORLD_SIZE} ${OUTPUT_PATH}/work_dir > ${OUTPUT_PATH}/train.log 2>&1 & +bash ./migrate_to_ascend/dist_train.sh ./projects/configs/PanoOcc/Panoptic/PanoOcc_base_4f.py ${WORLD_SIZE} ${OUTPUT_PATH}/work_dir wait # 训练结束时间 diff --git a/model_examples/PanoOcc/mmcv.patch b/model_examples/PanoOcc/mmcv.patch deleted file mode 100644 index 243223c1dd6a8bfcebad00f7bd1022dbdaed9009..0000000000000000000000000000000000000000 --- a/model_examples/PanoOcc/mmcv.patch +++ /dev/null @@ -1,582 +0,0 @@ -diff --git a/mmcv/ops/modulated_deform_conv.py b/mmcv/ops/modulated_deform_conv.py -index 8a348e83..58709c97 100644 ---- a/mmcv/ops/modulated_deform_conv.py -+++ b/mmcv/ops/modulated_deform_conv.py -@@ -1,248 +1,34 @@ - # Copyright (c) OpenMMLab. All rights reserved. 
-+# Copyright 2024 Huawei Technologies Co., Ltd - import math - from typing import Optional, Tuple, Union - - import torch -+import torch_npu - import torch.nn as nn --from torch.autograd import Function --from torch.autograd.function import once_differentiable - from torch.nn.modules.utils import _pair, _single -+from mmcv.utils import deprecated_api_warning -+from mx_driving import modulated_deform_conv2d, ModulatedDeformConv2dFunction - --from mmcv.utils import IS_MLU_AVAILABLE, deprecated_api_warning - from ..cnn import CONV_LAYERS --from ..utils import ext_loader, print_log -- --ext_module = ext_loader.load_ext( -- '_ext', -- ['modulated_deform_conv_forward', 'modulated_deform_conv_backward']) -- -- --class ModulatedDeformConv2dFunction(Function): -- -- @staticmethod -- def symbolic(g, input, offset, mask, weight, bias, stride, padding, -- dilation, groups, deform_groups): -- input_tensors = [input, offset, mask, weight] -- if bias is not None: -- input_tensors.append(bias) -- return g.op( -- 'mmcv::MMCVModulatedDeformConv2d', -- *input_tensors, -- stride_i=stride, -- padding_i=padding, -- dilation_i=dilation, -- groups_i=groups, -- deform_groups_i=deform_groups) -- -- @staticmethod -- def _calculate_sort_index(kernel_h, kernel_w, deformable_group): -- split_num = deformable_group * 2 * kernel_h * kernel_w -- sort_index = list(range(split_num)) -- sort_index_fp = (sort_index[1::2] + sort_index[::2]) -- sort_index_bp_dict = {i: idx for idx, i in enumerate(sort_index_fp)} -- sort_index_bp = [sort_index_bp_dict[i] for i in sort_index] -- sort_index_fp = torch.IntTensor(sort_index_fp) -- sort_index_bp = torch.IntTensor(sort_index_bp) -- sort_index_fp = sort_index_fp.npu() -- sort_index_bp = sort_index_bp.npu() -- return sort_index_fp, sort_index_bp -- -- @staticmethod -- def _npu_forward(ctx, input_tensor, offset, mask, weight, bias): -- _, _, kernel_h, kernel_w = weight.shape -- conv2d_bias = bias if len(bias) > 0 else None -- sort_index_fp, sort_index_bp = \ -- ModulatedDeformConv2dFunction._calculate_sort_index( -- kernel_w, kernel_h, ctx.deform_groups) -- select_offset = offset.index_select(1, sort_index_fp) -- offset_all = torch.cat([select_offset, mask], dim=1) -- import torch_npu -- output, offset_out = torch_npu.npu_deformable_conv2d( -- input_tensor, -- weight, -- offset_all, -- conv2d_bias, -- kernel_size=[kernel_w, kernel_h], -- stride=[1, 1, ctx.stride[0], ctx.stride[1]], -- padding=[ -- ctx.padding[0], ctx.padding[0], ctx.padding[1], ctx.padding[1] -- ], -- dilation=[1, 1, ctx.dilation[0], ctx.dilation[1]], -- groups=ctx.groups, -- deformable_groups=ctx.deform_groups, -- modulated=True) -- if weight.requires_grad or mask.requires_grad or offset.requires_grad \ -- or input_tensor.requires_grad: -- ctx.save_for_backward(input_tensor, weight, offset_out, offset_all, -- sort_index_bp) -- return output -- -- @staticmethod -- def _npu_backward(ctx, grad_output): -- input_tensor, weight, offset_out, offset_all, sort_index_bp = \ -- ctx.saved_tensors -- import torch_npu -- grad_input, grad_weight, grad_offset_all, grad_bias = \ -- torch_npu.npu_deformable_conv2dbk( -- input_tensor, grad_output, offset_out, weight, offset_all, -- kernel_size=[weight.shape[3], weight.shape[2]], -- stride=[1, 1, ctx.stride[0], ctx.stride[1]], -- padding=[ctx.padding[0], ctx.padding[0], ctx.padding[1], -- ctx.padding[1]], -- dilation=[1, 1, ctx.dilation[0], ctx.dilation[1]], -- groups=ctx.groups, deformable_groups=ctx.deform_groups, -- modulated=True) -- grad_offset = grad_offset_all.index_select(1, 
sort_index_bp) -- grad_mask = grad_offset_all[:, grad_offset.shape[1]:, :, :] -- if not ctx.with_bias: -- grad_bias = None -- return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias, -- None, None, None, None, None, None, None, None) -- -- @staticmethod -- def forward(ctx, -- input: torch.Tensor, -- offset: torch.Tensor, -- mask: torch.Tensor, -- weight: nn.Parameter, -- bias: Optional[nn.Parameter] = None, -- stride: int = 1, -- padding: int = 0, -- dilation: int = 1, -- groups: int = 1, -- deform_groups: int = 1) -> torch.Tensor: -- if input is not None and input.dim() != 4: -- raise ValueError( -- f'Expected 4D tensor as input, got {input.dim()}D tensor \ -- instead.') -- ctx.stride = _pair(stride) -- ctx.padding = _pair(padding) -- ctx.dilation = _pair(dilation) -- ctx.groups = groups -- ctx.deform_groups = deform_groups -- ctx.with_bias = bias is not None -- ctx.device = input.device.type -- if not ctx.with_bias: -- bias = input.new_empty(0) # fake tensor -- # When pytorch version >= 1.6.0, amp is adopted for fp16 mode; -- # amp won't cast the type of model (float32), but "offset" is cast -- # to float16 by nn.Conv2d automatically, leading to the type -- # mismatch with input (when it is float32) or weight. -- # The flag for whether to use fp16 or amp is the type of "offset", -- # we cast weight and input to temporarily support fp16 and amp -- # whatever the pytorch version is. -- input = input.type_as(offset) -- weight = weight.type_as(input) -- bias = bias.type_as(input) # type: ignore -- mask = mask.type_as(input) -- if ctx.device == 'npu': -- output = ModulatedDeformConv2dFunction._npu_forward( -- ctx, input, offset, mask, weight, bias) -- return output -- ctx.save_for_backward(input, offset, mask, weight, bias) -- output = input.new_empty( -- ModulatedDeformConv2dFunction._output_size(ctx, input, weight)) -- ctx._bufs = [input.new_empty(0), input.new_empty(0)] -- ext_module.modulated_deform_conv_forward( -- input, -- weight, -- bias, -- ctx._bufs[0], -- offset, -- mask, -- output, -- ctx._bufs[1], -- kernel_h=weight.size(2), -- kernel_w=weight.size(3), -- stride_h=ctx.stride[0], -- stride_w=ctx.stride[1], -- pad_h=ctx.padding[0], -- pad_w=ctx.padding[1], -- dilation_h=ctx.dilation[0], -- dilation_w=ctx.dilation[1], -- group=ctx.groups, -- deformable_group=ctx.deform_groups, -- with_bias=ctx.with_bias) -- return output -- -- @staticmethod -- @once_differentiable -- def backward(ctx, grad_output: torch.Tensor) -> tuple: -- if ctx.device == 'npu': -- return ModulatedDeformConv2dFunction._npu_backward( -- ctx, grad_output) -- input, offset, mask, weight, bias = ctx.saved_tensors -- grad_input = torch.zeros_like(input) -- grad_offset = torch.zeros_like(offset) -- grad_mask = torch.zeros_like(mask) -- grad_weight = torch.zeros_like(weight) -- grad_bias = torch.zeros_like(bias) -- grad_output = grad_output.contiguous() -- ext_module.modulated_deform_conv_backward( -- input, -- weight, -- bias, -- ctx._bufs[0], -- offset, -- mask, -- ctx._bufs[1], -- grad_input, -- grad_weight, -- grad_bias, -- grad_offset, -- grad_mask, -- grad_output, -- kernel_h=weight.size(2), -- kernel_w=weight.size(3), -- stride_h=ctx.stride[0], -- stride_w=ctx.stride[1], -- pad_h=ctx.padding[0], -- pad_w=ctx.padding[1], -- dilation_h=ctx.dilation[0], -- dilation_w=ctx.dilation[1], -- group=ctx.groups, -- deformable_group=ctx.deform_groups, -- with_bias=ctx.with_bias) -- if not ctx.with_bias: -- grad_bias = None -- -- return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias, -- None, None, 
None, None, None) -- -- @staticmethod -- def _output_size(ctx, input, weight): -- channels = weight.size(0) -- output_size = (input.size(0), channels) -- for d in range(input.dim() - 2): -- in_size = input.size(d + 2) -- pad = ctx.padding[d] -- kernel = ctx.dilation[d] * (weight.size(d + 2) - 1) + 1 -- stride_ = ctx.stride[d] -- output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, ) -- if not all(map(lambda s: s > 0, output_size)): -- raise ValueError( -- 'convolution input is too small (output would be ' + -- 'x'.join(map(str, output_size)) + ')') -- return output_size -- -- --modulated_deform_conv2d = ModulatedDeformConv2dFunction.apply -+from ..utils import print_log - - - class ModulatedDeformConv2d(nn.Module): - -- @deprecated_api_warning({'deformable_groups': 'deform_groups'}, -- cls_name='ModulatedDeformConv2d') -- def __init__(self, -- in_channels: int, -- out_channels: int, -- kernel_size: Union[int, Tuple[int]], -- stride: int = 1, -- padding: int = 0, -- dilation: int = 1, -- groups: int = 1, -- deform_groups: int = 1, -- bias: Union[bool, str] = True): -+ @deprecated_api_warning({"deformable_groups": "deform_groups"}, cls_name="ModulatedDeformConv2d") -+ def __init__( -+ self, -+ in_channels: int, -+ out_channels: int, -+ kernel_size: Union[int, Tuple[int]], -+ stride: int = 1, -+ padding: int = 0, -+ dilation: int = 1, -+ groups: int = 1, -+ deform_groups: int = 1, -+ bias: Union[bool, str] = True, -+ ): - super().__init__() - self.in_channels = in_channels - self.out_channels = out_channels -@@ -256,33 +42,38 @@ class ModulatedDeformConv2d(nn.Module): - self.transposed = False - self.output_padding = _single(0) - -- self.weight = nn.Parameter( -- torch.Tensor(out_channels, in_channels // groups, -- *self.kernel_size)) -+ self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)) - if bias: - self.bias = nn.Parameter(torch.Tensor(out_channels)) - else: -- self.register_parameter('bias', None) -+ self.register_parameter("bias", None) - self.init_weights() - - def init_weights(self): - n = self.in_channels - for k in self.kernel_size: - n *= k -- stdv = 1. / math.sqrt(n) -+ stdv = 1.0 / math.sqrt(n) - self.weight.data.uniform_(-stdv, stdv) - if self.bias is not None: - self.bias.data.zero_() - -- def forward(self, x: torch.Tensor, offset: torch.Tensor, -- mask: torch.Tensor) -> torch.Tensor: -- return modulated_deform_conv2d(x, offset, mask, self.weight, self.bias, -- self.stride, self.padding, -- self.dilation, self.groups, -- self.deform_groups) -+ def forward(self, x: torch.Tensor, offset: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: -+ return modulated_deform_conv2d( -+ x, -+ offset, -+ mask, -+ self.weight, -+ self.bias, -+ self.stride, -+ self.padding, -+ self.dilation, -+ self.groups, -+ self.deform_groups, -+ ) - - --@CONV_LAYERS.register_module('DCNv2') -+@CONV_LAYERS.register_module("DCNv2") - class ModulatedDeformConv2dPack(ModulatedDeformConv2d): - """A ModulatedDeformable Conv Encapsulation that acts as normal Conv - layers. 
-@@ -311,115 +102,53 @@ class ModulatedDeformConv2dPack(ModulatedDeformConv2d): - stride=self.stride, - padding=self.padding, - dilation=self.dilation, -- bias=True) -+ bias=True, -+ ) - self.init_weights() - - def init_weights(self) -> None: - super().init_weights() -- if hasattr(self, 'conv_offset'): -+ if hasattr(self, "conv_offset"): - self.conv_offset.weight.data.zero_() - self.conv_offset.bias.data.zero_() - - def forward(self, x: torch.Tensor) -> torch.Tensor: # type: ignore - out = self.conv_offset(x) -- o1, o2, mask = torch.chunk(out, 3, dim=1) -- offset = torch.cat((o1, o2), dim=1) -+ len1 = ((out.shape[1] + 2) // 3) * 2 -+ len2 = out.shape[1] - len1 -+ offset, mask = torch.split(out, [len1, len2], dim=1) - mask = torch.sigmoid(mask) -- return modulated_deform_conv2d(x, offset, mask, self.weight, self.bias, -- self.stride, self.padding, -- self.dilation, self.groups, -- self.deform_groups) -- -- def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, -- missing_keys, unexpected_keys, error_msgs): -- version = local_metadata.get('version', None) -+ return modulated_deform_conv2d( -+ x, -+ offset, -+ mask, -+ self.weight, -+ self.bias, -+ self.stride, -+ self.padding, -+ self.dilation, -+ self.groups, -+ self.deform_groups, -+ ) -+ -+ # pylint: disable=huawei-too-many-arguments -+ def _load_from_state_dict( -+ self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs -+ ): -+ version = local_metadata.get("version", None) - - if version is None or version < 2: - # the key is different in early versions - # In version < 2, ModulatedDeformConvPack - # loads previous benchmark models. -- if (prefix + 'conv_offset.weight' not in state_dict -- and prefix[:-1] + '_offset.weight' in state_dict): -- state_dict[prefix + 'conv_offset.weight'] = state_dict.pop( -- prefix[:-1] + '_offset.weight') -- if (prefix + 'conv_offset.bias' not in state_dict -- and prefix[:-1] + '_offset.bias' in state_dict): -- state_dict[prefix + -- 'conv_offset.bias'] = state_dict.pop(prefix[:-1] + -- '_offset.bias') -+ if prefix + "conv_offset.weight" not in state_dict and prefix[:-1] + "_offset.weight" in state_dict: -+ state_dict[prefix + "conv_offset.weight"] = state_dict.pop(prefix[:-1] + "_offset.weight") -+ if prefix + "conv_offset.bias" not in state_dict and prefix[:-1] + "_offset.bias" in state_dict: -+ state_dict[prefix + "conv_offset.bias"] = state_dict.pop(prefix[:-1] + "_offset.bias") - - if version is not None and version > 1: -- print_log( -- f'ModulatedDeformConvPack {prefix.rstrip(".")} is upgraded to ' -- 'version 2.', -- logger='root') -- -- super()._load_from_state_dict(state_dict, prefix, local_metadata, -- strict, missing_keys, unexpected_keys, -- error_msgs) -- -- --if IS_MLU_AVAILABLE: -- import torchvision -- from torchvision.ops import deform_conv2d as tv_deform_conv2d -- -- from mmcv.utils import digit_version -- -- @CONV_LAYERS.register_module('DCNv2', force=True) -- class ModulatedDeformConv2dPack_MLU(ModulatedDeformConv2d): -- """This class is the DCNv2 implementation of the MLU device. The MLU -- backend support of the operator has been implemented in torchvision. -- The mmcv registration mechanism is used for multiplexing here. The -- torchvision implementation of DCNv2 is called. -- -- Args: -- in_channels (int): Same as nn.Conv2d. -- out_channels (int): Same as nn.Conv2d. -- kernel_size (int or tuple[int]): Same as nn.Conv2d. -- stride (int): Same as nn.Conv2d, while tuple is not supported. 
-- padding (int): Same as nn.Conv2d, while tuple is not supported. -- dilation (int): Same as nn.Conv2d, while tuple is not supported. -- groups (int): Same as nn.Conv2d. -- bias (bool or str): If specified as `auto`, it will be decided by -- the norm_cfg. Bias will be set as True if norm_cfg is None, -- otherwise False. -- """ -- -- def __init__(self, *args, **kwargs): -- assert digit_version(torchvision.__version__) >= digit_version( -- '0.10.0a0'), 'the version of torchvision should be >= 0.10.0' -- super().__init__(*args, **kwargs) -- self.conv_offset = nn.Conv2d( -- self.in_channels, -- self.deform_groups * 3 * self.kernel_size[0] * -- self.kernel_size[1], -- kernel_size=self.kernel_size, -- stride=self.stride, -- padding=self.padding, -- dilation=self.dilation, -- bias=True) -- self.init_weights() -- -- def init_weights(self): -- super().init_weights() -- if hasattr(self, 'conv_offset'): -- self.conv_offset.weight.data.zero_() -- self.conv_offset.bias.data.zero_() -+ print_log(f'ModulatedDeformConvPack {prefix.rstrip(".")} is upgraded to ' "version 2.", logger="root") - -- def forward(self, x): -- out = self.conv_offset(x) -- o1, o2, mask = torch.chunk(out, 3, dim=1) -- offset = torch.cat((o1, o2), dim=1) -- mask = torch.sigmoid(mask) -- x = x.type_as(offset) -- weight = self.weight.type_as(x) -- mask = mask.type_as(x) -- return tv_deform_conv2d( -- x, -- offset, -- weight, -- bias=self.bias, -- stride=self.stride, -- padding=self.padding, -- dilation=self.dilation, -- mask=mask) -+ super()._load_from_state_dict( -+ state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs -+ ) -diff --git a/mmcv/parallel/distributed.py b/mmcv/parallel/distributed.py -index bf34cb59..43af486c 100644 ---- a/mmcv/parallel/distributed.py -+++ b/mmcv/parallel/distributed.py -@@ -1,4 +1,5 @@ - # Copyright (c) OpenMMLab. All rights reserved. -+# Copyright 2024 Huawei Technologies Co., Ltd - from typing import Any, List, Tuple - - import torch -@@ -156,8 +157,7 @@ class MMDistributedDataParallel(DistributedDataParallel): - Returns: - Any: Forward result of :attr:`module`. - """ -- module_to_run = self._replicated_tensor_module if \ -- self._use_replicated_tensor_module else self.module -+ module_to_run = self.module - - if self.device_ids: - inputs, kwargs = self.to_kwargs( # type: ignore -diff --git a/mmcv/runner/hooks/optimizer.py b/mmcv/runner/hooks/optimizer.py -index 93015475..e45a8f2b 100644 ---- a/mmcv/runner/hooks/optimizer.py -+++ b/mmcv/runner/hooks/optimizer.py -@@ -1,4 +1,19 @@ -+# Copyright 2024 Huawei Technologies Co., Ltd -+# -+# Licensed under the Apache License, Version 2.0 (the "License"); -+# you may not use this file except in compliance with the License. -+# You may obtain a copy of the License at -+# -+# http://www.apache.org/licenses/LICENSE-2.0 -+# -+# Unless required by applicable law or agreed to in writing, software -+# distributed under the License is distributed on an "AS IS" BASIS, -+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+# See the License for the specific language governing permissions and -+# limitations under the License. -+ - # Copyright (c) OpenMMLab. All rights reserved. 
-+# Copyright 2024 Huawei Technologies Co., Ltd - import copy - import logging - from collections import defaultdict -@@ -52,11 +67,11 @@ class OptimizerHook(Hook): - self.grad_clip = grad_clip - self.detect_anomalous_params = detect_anomalous_params - -- def clip_grads(self, params): -+ def clip_grads(self, params, runner): - params = list( - filter(lambda p: p.requires_grad and p.grad is not None, params)) - if len(params) > 0: -- return clip_grad.clip_grad_norm_(params, **self.grad_clip) -+ return runner.optimizer.clip_grad_norm_fused_(**self.grad_clip) - - def after_train_iter(self, runner): - runner.optimizer.zero_grad() -@@ -65,7 +80,7 @@ class OptimizerHook(Hook): - runner.outputs['loss'].backward() - - if self.grad_clip is not None: -- grad_norm = self.clip_grads(runner.model.parameters()) -+ grad_norm = self.clip_grads(runner.model.parameters(), runner) - if grad_norm is not None: - # Add grad norm to the logger - runner.log_buffer.update({'grad_norm': float(grad_norm)}, -@@ -182,7 +197,7 @@ class GradientCumulativeOptimizerHook(OptimizerHook): - or self.is_last_iter(runner)): - - if self.grad_clip is not None: -- grad_norm = self.clip_grads(runner.model.parameters()) -+ grad_norm = self.clip_grads(runner.model.parameters(), runner) - if grad_norm is not None: - # Add grad norm to the logger - runner.log_buffer.update({'grad_norm': float(grad_norm)}, -@@ -291,7 +306,7 @@ if (TORCH_VERSION != 'parrots' - self.loss_scaler.unscale_(runner.optimizer) - # grad clip - if self.grad_clip is not None: -- grad_norm = self.clip_grads(runner.model.parameters()) -+ grad_norm = self.clip_grads(runner.model.parameters(), runner) - if grad_norm is not None: - # Add grad norm to the logger - runner.log_buffer.update({'grad_norm': float(grad_norm)}, -@@ -331,7 +346,7 @@ if (TORCH_VERSION != 'parrots' - self.loss_scaler.unscale_(runner.optimizer) - - if self.grad_clip is not None: -- grad_norm = self.clip_grads(runner.model.parameters()) -+ grad_norm = self.clip_grads(runner.model.parameters(), runner) - if grad_norm is not None: - # Add grad norm to the logger - runner.log_buffer.update( -@@ -477,7 +492,7 @@ else: - if param.grad is not None: - param.grad.div_(self.loss_scaler.loss_scale) - if self.grad_clip is not None: -- grad_norm = self.clip_grads(fp32_weights) -+ grad_norm = self.clip_grads(fp32_weights, runner) - if grad_norm is not None: - # Add grad norm to the logger - runner.log_buffer.update( -@@ -534,7 +549,7 @@ else: - if param.grad is not None: - param.grad.div_(self.loss_scaler.loss_scale) - if self.grad_clip is not None: -- grad_norm = self.clip_grads(fp32_weights) -+ grad_norm = self.clip_grads(fp32_weights, runner) - if grad_norm is not None: - # Add grad norm to the logger - runner.log_buffer.update( diff --git a/model_examples/PanoOcc/mmdetection.patch b/model_examples/PanoOcc/mmdetection.patch deleted file mode 100644 index 74627ae58b07551f0185ee32d22ad053f89c6e3e..0000000000000000000000000000000000000000 --- a/model_examples/PanoOcc/mmdetection.patch +++ /dev/null @@ -1,74 +0,0 @@ -diff --git a/mmdet/__init__.py b/mmdet/__init__.py -index 1f8ee169..d9e6ba13 100644 ---- a/mmdet/__init__.py -+++ b/mmdet/__init__.py -@@ -1,4 +1,5 @@ - # Copyright (c) OpenMMLab. All rights reserved. 
-+# Copyright 2024 Huawei Technologies Co., Ltd - import mmcv - - from .version import __version__, short_version -@@ -17,7 +18,7 @@ def digit_version(version_str): - - - mmcv_minimum_version = '1.3.17' --mmcv_maximum_version = '1.6.0' -+mmcv_maximum_version = '1.7.2' - mmcv_version = digit_version(mmcv.__version__) - - -diff --git a/mmdet/models/backbones/resnet.py b/mmdet/models/backbones/resnet.py -index 1eaaae67..9d224e54 100644 ---- a/mmdet/models/backbones/resnet.py -+++ b/mmdet/models/backbones/resnet.py -@@ -1,4 +1,5 @@ - # Copyright (c) OpenMMLab. All rights reserved. -+# Copyright 2024 Huawei Technologies Co., Ltd - import warnings - - import torch.nn as nn -@@ -6,6 +7,9 @@ import torch.utils.checkpoint as cp - from mmcv.cnn import build_conv_layer, build_norm_layer, build_plugin_layer - from mmcv.runner import BaseModule - from torch.nn.modules.batchnorm import _BatchNorm -+import mx_driving -+import torch -+import torch_npu - - from ..builder import BACKBONES - from ..utils import ResLayer -@@ -288,7 +292,7 @@ class Bottleneck(BaseModule): - if self.downsample is not None: - identity = self.downsample(x) - -- out += identity -+ out = mx_driving.npu_add_relu(out, identity) - - return out - -@@ -297,8 +301,6 @@ class Bottleneck(BaseModule): - else: - out = _inner_forward(x) - -- out = self.relu(out) -- - return out - - -@@ -608,7 +610,6 @@ class ResNet(BaseModule): - self.norm_cfg, stem_channels, postfix=1) - self.add_module(self.norm1_name, norm1) - self.relu = nn.ReLU(inplace=True) -- self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) - - def _freeze_stages(self): - if self.frozen_stages >= 0: -@@ -636,7 +637,7 @@ class ResNet(BaseModule): - x = self.conv1(x) - x = self.norm1(x) - x = self.relu(x) -- x = self.maxpool(x) -+ x = mx_driving.npu_max_pool2d(x, 3, 2, 1) - outs = [] - for i, layer_name in enumerate(self.res_layers): - res_layer = getattr(self, layer_name) diff --git a/model_examples/PanoOcc/panoocc.patch b/model_examples/PanoOcc/panoocc.patch deleted file mode 100644 index 282551419d1602a69f55d0f289074d141ed6518b..0000000000000000000000000000000000000000 --- a/model_examples/PanoOcc/panoocc.patch +++ /dev/null @@ -1,1443 +0,0 @@ -diff --git a/projects/configs/PanoOcc/Panoptic/PanoOcc_base_4f.py b/projects/configs/PanoOcc/Panoptic/PanoOcc_base_4f.py -index 58a49f0..6c5e8e1 100644 ---- a/projects/configs/PanoOcc/Panoptic/PanoOcc_base_4f.py -+++ b/projects/configs/PanoOcc/Panoptic/PanoOcc_base_4f.py -@@ -307,7 +307,7 @@ data = dict( - ) - - optimizer = dict( -- type='AdamW', -+ type='NpuFusedAdamW', - lr=4e-4, - paramwise_cfg=dict( - custom_keys={ -diff --git a/projects/configs/_base_/default_runtime.py b/projects/configs/_base_/default_runtime.py -index 4e85b69..cd301c6 100644 ---- a/projects/configs/_base_/default_runtime.py -+++ b/projects/configs/_base_/default_runtime.py -@@ -10,7 +10,7 @@ log_config = dict( - dict(type='TensorboardLoggerHook') - ]) - # yapf:enable --dist_params = dict(backend='nccl') -+dist_params = dict(backend='hccl') - log_level = 'INFO' - work_dir = None - load_from = None -diff --git a/projects/mmdet3d_plugin/bevformer/apis/mmdet_train.py b/projects/mmdet3d_plugin/bevformer/apis/mmdet_train.py -index e57bd22..fc4600b 100644 ---- a/projects/mmdet3d_plugin/bevformer/apis/mmdet_train.py -+++ b/projects/mmdet3d_plugin/bevformer/apis/mmdet_train.py -@@ -9,7 +9,7 @@ import warnings - import numpy as np - import torch - import torch.distributed as dist --from mmcv.parallel import MMDataParallel, MMDistributedDataParallel -+from 
mmcv.device.npu import NPUDataParallel, NPUDistributedDataParallel - from mmcv.runner import (HOOKS, DistSamplerSeedHook, EpochBasedRunner, - Fp16OptimizerHook, OptimizerHook, build_optimizer, - build_runner, get_dist_info) -@@ -64,6 +64,7 @@ def custom_train_detector(model, - seed=cfg.seed, - shuffler_sampler=cfg.data.shuffler_sampler, # dict(type='DistributedGroupSampler'), - nonshuffler_sampler=cfg.data.nonshuffler_sampler, # dict(type='DistributedSampler'), -+ pin_memory=True - ) for ds in dataset - ] - -@@ -72,22 +73,22 @@ def custom_train_detector(model, - find_unused_parameters = cfg.get('find_unused_parameters', False) - # Sets the `find_unused_parameters` parameter in - # torch.nn.parallel.DistributedDataParallel -- model = MMDistributedDataParallel( -+ model = NPUDistributedDataParallel( - model.cuda(), - device_ids=[torch.cuda.current_device()], - broadcast_buffers=False, - find_unused_parameters=find_unused_parameters) - if eval_model is not None: -- eval_model = MMDistributedDataParallel( -+ eval_model = NPUDistributedDataParallel( - eval_model.cuda(), - device_ids=[torch.cuda.current_device()], - broadcast_buffers=False, - find_unused_parameters=find_unused_parameters) - else: -- model = MMDataParallel( -+ model = NPUDataParallel( - model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids) - if eval_model is not None: -- eval_model = MMDataParallel( -+ eval_model = NPUDataParallel( - eval_model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids) - - -diff --git a/projects/mmdet3d_plugin/bevformer/dense_heads/__init__.py b/projects/mmdet3d_plugin/bevformer/dense_heads/__init__.py -index a831372..8d9815a 100644 ---- a/projects/mmdet3d_plugin/bevformer/dense_heads/__init__.py -+++ b/projects/mmdet3d_plugin/bevformer/dense_heads/__init__.py -@@ -1,3 +1,2 @@ - from .pano_occ_head import PanoOccHead --from .panoseg_occ_head import PanoSegOccHead --from .panoseg_occ_sparse_head import SparseOccupancyHead -\ No newline at end of file -+from .panoseg_occ_head import PanoSegOccHead -\ No newline at end of file -diff --git a/projects/mmdet3d_plugin/bevformer/dense_heads/panoseg_occ_head.py b/projects/mmdet3d_plugin/bevformer/dense_heads/panoseg_occ_head.py -index cc986f0..aadce74 100644 ---- a/projects/mmdet3d_plugin/bevformer/dense_heads/panoseg_occ_head.py -+++ b/projects/mmdet3d_plugin/bevformer/dense_heads/panoseg_occ_head.py -@@ -18,10 +18,42 @@ import numpy as np - import mmcv - import cv2 as cv - from projects.mmdet3d_plugin.models.utils.visual import save_tensor --from mmdet3d.ops import scatter_v2 --import torch_scatter -+import mx_driving -+import mx_driving._C - from mmdet.models.builder import build_loss - -+ -+def custom_unique_n3(coors, return_inverse, return_counts, dim): -+ # assert dim == 0 -+ -+ voxels = mx_driving._C.point_to_voxel(coors, [], [], "ZYX") -+ cnt, unq_voxels, unq_ind, argsort_ind, _ = mx_driving._C.unique_voxel(voxels) -+ unq_coors = mx_driving._C.voxel_to_point(unq_voxels, [], [], "ZYX") -+ -+ if return_inverse: -+ sorted_ind = torch.argsort(argsort_ind.to(torch.float32), dim=dim).to(torch.long) -+ is_unq = torch.zeros(coors.size(0)).to(coors.device) -+ is_unq[unq_ind] = 1 -+ unq_inv_sorted = is_unq.cumsum(dim) - 1 -+ unq_inv = torch.gather(unq_inv_sorted, dim, sorted_ind) -+ unq_inv = unq_inv.to(torch.int64) -+ -+ if return_counts: -+ unq_ind_nxt = torch.ones_like(unq_ind) * coors.size(0) -+ unq_ind_nxt[:-1] = unq_ind[1:] -+ unq_cnt = unq_ind_nxt - unq_ind -+ unq_cnt = unq_cnt.to(torch.int64) -+ -+ if return_inverse and return_counts: -+ return unq_coors, unq_inv, 
unq_cnt -+ elif return_inverse: -+ return unq_coors, unq_inv -+ elif return_counts: -+ return unq_coors, unq_cnt -+ else: -+ return unq_coors -+ -+ - @HEADS.register_module() - class PanoSegOccHead(DETRHead): - def __init__(self, -@@ -35,22 +67,22 @@ class PanoSegOccHead(DETRHead): - bev_h=30, - bev_w=30, - bev_z=5, -- voxel_lidar = [0.05, 0.05, 0.05], -- voxel_det = [2.048,2.048,1], -+ voxel_lidar=[0.05, 0.05, 0.05], -+ voxel_det=[2.048,2.048,1], - loss_occupancy=dict( - type='FocalLoss', - use_sigmoid=True, - gamma=2.0, - alpha=0.25, - loss_weight=5.0), -- loss_occupancy_aux = None, -+ loss_occupancy_aux=None, - loss_occupancy_det=dict( - type='FocalLoss', - use_sigmoid=True, - gamma=2.0, - alpha=0.25, - loss_weight=5.0), -- bg_weight = 0.02, -+ bg_weight=0.02, - **kwargs): - - self.bev_h = bev_h -@@ -88,6 +120,13 @@ class PanoSegOccHead(DETRHead): - if loss_occupancy_aux is not None: - self.lidar_seg_aux_loss = build_loss(loss_occupancy_aux) - -+ self.pc_range = nn.Parameter(torch.tensor( -+ self.pc_range, requires_grad=False), requires_grad=False) -+ self.voxel_lidar = nn.Parameter(torch.tensor( -+ self.voxel_lidar, requires_grad=False), requires_grad=False) -+ self.voxel_det = nn.Parameter(torch.tensor( -+ self.voxel_det, requires_grad=False), requires_grad=False) -+ - def _init_layers(self): - """Initialize classification branch and regression branch of head.""" - cls_branch = [] -@@ -159,7 +198,7 @@ class PanoSegOccHead(DETRHead): - object_query_embeds = self.query_embedding.weight.to(dtype) - bev_queries = self.bev_embedding.weight.to(dtype) - -- bev_mask = torch.zeros((bs, self.bev_h, self.bev_w, self.bev_z),device=bev_queries.device).to(dtype) -+ bev_mask = torch.zeros((bs, self.bev_h, self.bev_w, self.bev_z), device=bev_queries.device, dtype=dtype) - bev_pos = self.positional_encoding(bev_mask).to(dtype) - - if only_bev: -@@ -180,21 +219,21 @@ class PanoSegOccHead(DETRHead): - ) - bev_feat, bev_embed, hs, init_reference, inter_references, occupancy, occupancy_det = outputs - return bev_feat, bev_embed -- else: -- outputs = self.transformer( -- mlvl_feats, -- bev_queries, -- object_query_embeds, -- self.bev_h, -- self.bev_w, -- self.bev_z, -- grid_length=(self.real_h / self.bev_h, -- self.real_w / self.bev_w), -- bev_pos=bev_pos, -- reg_branches=self.reg_branches if self.with_box_refine else None, # noqa:E501 -- cls_branches=self.cls_branches if self.as_two_stage else None, -- img_metas=img_metas, -- prev_bev=prev_bev -+ -+ outputs = self.transformer( -+ mlvl_feats, -+ bev_queries, -+ object_query_embeds, -+ self.bev_h, -+ self.bev_w, -+ self.bev_z, -+ grid_length=(self.real_h / self.bev_h, -+ self.real_w / self.bev_w), -+ bev_pos=bev_pos, -+ reg_branches=self.reg_branches if self.with_box_refine else None, # noqa:E501 -+ cls_branches=self.cls_branches if self.as_two_stage else None, -+ img_metas=img_metas, -+ prev_bev=prev_bev - ) - - bev_feat, bev_embed, hs, init_reference, inter_references, occupancy, occupancy_det = outputs -@@ -211,7 +250,7 @@ class PanoSegOccHead(DETRHead): - tmp = self.reg_branches[lvl](hs[lvl]) - - # TODO: check the shape of reference -- assert reference.shape[-1] == 3 -+ # assert reference.shape[-1] == 3 - tmp[..., 0:2] += reference[..., 0:2] - tmp[..., 0:2] = tmp[..., 0:2].sigmoid() - tmp[..., 4:5] += reference[..., 2:3] -@@ -279,7 +318,7 @@ class PanoSegOccHead(DETRHead): - gt_c = gt_bboxes.shape[-1] - - assign_result = self.assigner.assign(bbox_pred, cls_score, gt_bboxes, -- gt_labels, gt_bboxes_ignore) -+ gt_labels) - - sampling_result = 
self.sampler.sample(assign_result, bbox_pred, - gt_bboxes) -@@ -338,17 +377,11 @@ class PanoSegOccHead(DETRHead): - - num_total_neg (int): Number of negative samples in all \ - images. - """ -- assert gt_bboxes_ignore_list is None, \ -- 'Only supports for gt_bboxes_ignore setting to None.' -- num_imgs = len(cls_scores_list) -- gt_bboxes_ignore_list = [ -- gt_bboxes_ignore_list for _ in range(num_imgs) -- ] - - (labels_list, label_weights_list, bbox_targets_list, - bbox_weights_list, pos_inds_list, neg_inds_list) = multi_apply( - self._get_target_single, cls_scores_list, bbox_preds_list, -- gt_labels_list, gt_bboxes_list, gt_bboxes_ignore_list) -+ gt_labels_list, gt_bboxes_list) - num_total_pos = sum((inds.numel() for inds in pos_inds_list)) - num_total_neg = sum((inds.numel() for inds in neg_inds_list)) - return (labels_list, label_weights_list, bbox_targets_list, -@@ -382,8 +415,7 @@ class PanoSegOccHead(DETRHead): - cls_scores_list = [cls_scores[i] for i in range(num_imgs)] - bbox_preds_list = [bbox_preds[i] for i in range(num_imgs)] - cls_reg_targets = self.get_targets(cls_scores_list, bbox_preds_list, -- gt_bboxes_list, gt_labels_list, -- gt_bboxes_ignore_list) -+ gt_bboxes_list, gt_labels_list) - (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list, - num_total_pos, num_total_neg) = cls_reg_targets - labels = torch.cat(labels_list, 0) -@@ -394,11 +426,9 @@ class PanoSegOccHead(DETRHead): - # classification loss - cls_scores = cls_scores.reshape(-1, self.cls_out_channels) - # construct weighted avg_factor to match with the official DETR repo -- cls_avg_factor = num_total_pos * 1.0 + \ -- num_total_neg * self.bg_cls_weight -+ cls_avg_factor = num_total_pos * 1.0 + num_total_neg * self.bg_cls_weight - if self.sync_cls_avg_factor: -- cls_avg_factor = reduce_mean( -- cls_scores.new_tensor([cls_avg_factor])) -+ cls_avg_factor = reduce_mean(cls_scores.new_tensor([cls_avg_factor])) - - cls_avg_factor = max(cls_avg_factor, 1) - -@@ -417,7 +447,7 @@ class PanoSegOccHead(DETRHead): - bbox_weights = bbox_weights * self.code_weights - - loss_bbox = self.loss_bbox( -- bbox_preds[isnotnan, :10], normalized_bbox_targets[isnotnan,:10], bbox_weights[isnotnan, :10], -+ bbox_preds[isnotnan, :10], normalized_bbox_targets[isnotnan, :10], bbox_weights[isnotnan, :10], - avg_factor=num_total_pos) - if digit_version(TORCH_VERSION) >= digit_version('1.8'): - loss_cls = torch.nan_to_num(loss_cls) -@@ -426,13 +456,13 @@ class PanoSegOccHead(DETRHead): - - def get_occupancy_det_label(self,voxel_coors_det, voxel_label_det, occupancy_det_label): - -- voxel_coors_det[:,1] = voxel_coors_det[:,1].clip(min=0,max=self.bev_z-1) -- voxel_coors_det[:,2] = voxel_coors_det[:,2].clip(min=0,max=self.bev_h-1) -- voxel_coors_det[:,3] = voxel_coors_det[:,3].clip(min=0,max=self.bev_w-1) -+ voxel_coors_det[:, 0] = voxel_coors_det[:, 0].clip(min=0, max=self.bev_z-1) -+ voxel_coors_det[:, 1] = voxel_coors_det[:, 1].clip(min=0, max=self.bev_h-1) -+ voxel_coors_det[:, 2] = voxel_coors_det[:, 2].clip(min=0, max=self.bev_w-1) - -- det_label_binary = ((voxel_label_det>=1)&(voxel_label_det<=10)) -+ det_label_binary = ((voxel_label_det>=1) & (voxel_label_det<=10)) - det_label = det_label_binary.long() -- occupancy_det_label[0,voxel_coors_det[:,1],voxel_coors_det[:,2],voxel_coors_det[:,3]]=det_label -+ occupancy_det_label[0, voxel_coors_det[:, 0], voxel_coors_det[:, 1], voxel_coors_det[:, 2]] = det_label - return occupancy_det_label - - def get_det_loss(self,voxel_label_det,occupancy_det_label,occupancy_det_pred): -@@ -446,7 
+476,7 @@ class PanoSegOccHead(DETRHead): - occupancy_det_pred.new_tensor([avg_factor_det])) - avg_factor_det = max(avg_factor_det, 1) - -- losses_det = self.lidar_det_loss(occupancy_det_pred,occupancy_det_label,avg_factor=avg_factor_det) -+ losses_det = self.lidar_det_loss(occupancy_det_pred, occupancy_det_label, avg_factor=avg_factor_det) - return losses_det - - @force_fp32(apply_to=('preds_dicts')) -@@ -455,7 +485,7 @@ class PanoSegOccHead(DETRHead): - gt_labels_list, - pts_sem, - preds_dicts, -- dense_occupancy = None, -+ dense_occupancy=None, - gt_bboxes_ignore=None, - img_metas=None): - """"Loss function. -@@ -485,9 +515,6 @@ class PanoSegOccHead(DETRHead): - Returns: - dict[str, Tensor]: A dictionary of loss components. - """ -- assert gt_bboxes_ignore is None, \ -- f'{self.__class__.__name__} only supports ' \ -- f'for gt_bboxes_ignore setting to None.' - - # Extract the first three columns from pts_sem - pts = pts_sem[:, :3] -@@ -497,12 +524,10 @@ class PanoSegOccHead(DETRHead): - - # If dense_occupancy is None, perform voxelization and label voxelization - if dense_occupancy is None: -- pts_coors, voxelized_data, voxel_coors = self.voxelize(pts, self.pc_range, self.voxel_lidar) -- voxel_label = self.label_voxelization(pts_semantic_mask, pts_coors, voxel_coors) -+ _, voxel_coors, voxel_label = self.voxelize(pts, self.pc_range, self.voxel_lidar, pts_semantic_mask) - - # Perform voxelization and label voxelization for detection -- pts_coors_det, voxelized_data_det, voxel_coors_det = self.voxelize(pts, self.pc_range, self.voxel_det) -- voxel_label_det = self.label_voxelization(pts_semantic_mask, pts_coors_det, voxel_coors_det) -+ _, voxel_coors_det, voxel_label_det = self.voxelize(pts, self.pc_range, self.voxel_det, pts_semantic_mask) - - all_cls_scores = preds_dicts['all_cls_scores'] - all_bbox_preds = preds_dicts['all_bbox_preds'] -@@ -514,31 +539,31 @@ class PanoSegOccHead(DETRHead): - occupancy_pred = occupancy.squeeze(0) - occupancy_det_pred = occupancy_det[0].squeeze(0) - -- cls_num,occ_z,occ_h,occ_w = occupancy_pred.shape -+ cls_num, occ_z, occ_h, occ_w = occupancy_pred.shape - if dense_occupancy is None: - occupancy_label = torch.full((1, occ_z, occ_h, occ_w), cls_num, device=occupancy_pred.device, dtype=torch.long) - else: -- occupancy_label = (torch.zeros(1,occ_z,occ_h,occ_w)).to(occupancy_pred.device).long() -+ occupancy_label = (torch.zeros(1, occ_z, occ_h, occ_w)).to(occupancy_pred.device).long() - -- occupancy_det_label = (torch.ones(1,self.bev_z,self.bev_h,self.bev_w)*2).to(occupancy_det_pred.device).long() -+ occupancy_det_label = (torch.ones(1, self.bev_z, self.bev_h, self.bev_w) * 2).to(occupancy_det_pred.device).long() - - if dense_occupancy is None: -- voxel_coors[:,1] = voxel_coors[:,1].clip(min=0,max=occ_z-1) -- voxel_coors[:,2] = voxel_coors[:,2].clip(min=0,max=occ_h-1) -- voxel_coors[:,3] = voxel_coors[:,3].clip(min=0,max=occ_w-1) -- occupancy_label[0,voxel_coors[:,1],voxel_coors[:,2],voxel_coors[:,3]] = voxel_label -+ voxel_coors[:, 0] = voxel_coors[:, 0].clip(min=0, max=occ_z-1) -+ voxel_coors[:, 1] = voxel_coors[:, 1].clip(min=0, max=occ_h-1) -+ voxel_coors[:, 2] = voxel_coors[:, 2].clip(min=0, max=occ_w-1) -+ occupancy_label[0, voxel_coors[:, 0], voxel_coors[:, 1], voxel_coors[:, 2]] = voxel_label - else: - dense_occupancy = dense_occupancy.long().squeeze(0) -- occupancy_label[0,dense_occupancy[:,0],dense_occupancy[:,1],dense_occupancy[:,2]]=dense_occupancy[:,3] -+ occupancy_label[0, dense_occupancy[:, 0], dense_occupancy[:, 1], dense_occupancy[:, 2]] = 
dense_occupancy[:, 3] - - occupancy_det_label = self.get_occupancy_det_label(voxel_coors_det, voxel_label_det, occupancy_det_label) - -- losses_seg_aux = self.lidar_seg_aux_loss(occupancy_pred.unsqueeze(0),occupancy_label) -+ losses_seg_aux = self.lidar_seg_aux_loss(occupancy_pred.unsqueeze(0), occupancy_label) - - occupancy_det_label = occupancy_det_label.reshape(-1) - occupancy_label = occupancy_label.reshape(-1) - -- assert occupancy_label.max()<=cls_num and occupancy_label.min()>=0 -+ # assert occupancy_label.max()<=cls_num and occupancy_label.min()>=0 - occupancy_pred = occupancy_pred.reshape(cls_num,-1).permute(1,0) - occupancy_det_pred = occupancy_det_pred.reshape(2,-1).permute(1,0) - -@@ -550,14 +575,10 @@ class PanoSegOccHead(DETRHead): - - all_gt_bboxes_list = [gt_bboxes_list for _ in range(num_dec_layers)] - all_gt_labels_list = [gt_labels_list for _ in range(num_dec_layers)] -- all_gt_bboxes_ignore_list = [ -- gt_bboxes_ignore for _ in range(num_dec_layers) -- ] - - losses_cls, losses_bbox = multi_apply( - self.loss_single, all_cls_scores, all_bbox_preds, -- all_gt_bboxes_list, all_gt_labels_list, -- all_gt_bboxes_ignore_list) -+ all_gt_bboxes_list, all_gt_labels_list) - - loss_dict = dict() - -@@ -566,17 +587,17 @@ class PanoSegOccHead(DETRHead): - num_total_pos = len(voxel_label) - else: - num_total_pos = len(dense_occupancy) -- num_total_neg = len(occupancy_label)-num_total_pos -+ num_total_neg = len(occupancy_label) - num_total_pos - avg_factor = num_total_pos * 1.0 + num_total_neg * self.bg_weight - if self.sync_cls_avg_factor: - avg_factor = reduce_mean( - occupancy_pred.new_tensor([avg_factor])) - avg_factor = max(avg_factor, 1) - -- losses_seg = self.lidar_seg_loss(occupancy_pred,occupancy_label,avg_factor=avg_factor) -+ losses_seg = self.lidar_seg_loss(occupancy_pred, occupancy_label, avg_factor=avg_factor) - - # Lidar det loss -- losses_det = self.get_det_loss(voxel_label_det,occupancy_det_label,occupancy_det_pred) -+ losses_det = self.get_det_loss(voxel_label_det, occupancy_det_label, occupancy_det_pred) - - # loss of proposal generated from encode feature map. - if enc_cls_scores is not None: -@@ -586,7 +607,7 @@ class PanoSegOccHead(DETRHead): - ] - enc_loss_cls, enc_losses_bbox = \ - self.loss_single(enc_cls_scores, enc_bbox_preds, -- gt_bboxes_list, binary_labels_list, gt_bboxes_ignore) -+ gt_bboxes_list, binary_labels_list) - loss_dict['enc_loss_cls'] = enc_loss_cls - loss_dict['enc_loss_bbox'] = enc_losses_bbox - -@@ -599,187 +620,35 @@ class PanoSegOccHead(DETRHead): - - # loss from other decoder layers - num_dec_layer = 0 -- for loss_cls_i, loss_bbox_i in zip(losses_cls[:-1], -- losses_bbox[:-1]): -+ for loss_cls_i, loss_bbox_i in zip(losses_cls[:-1], losses_bbox[:-1]): - loss_dict[f'd{num_dec_layer}.loss_cls'] = loss_cls_i - loss_dict[f'd{num_dec_layer}.loss_bbox'] = loss_bbox_i - num_dec_layer += 1 - - return loss_dict - -- @force_fp32(apply_to=('preds_dicts')) -- def loss_new(self, -- gt_bboxes_list, -- gt_labels_list, -- pts_sem, -- preds_dicts, -- dense_occupancy = None, -- gt_bboxes_ignore=None, -- img_metas=None): -- """"Loss function. -- Args: -- -- gt_bboxes_list (list[Tensor]): Ground truth bboxes for each image -- with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format. -- gt_labels_list (list[Tensor]): Ground truth class indices for each -- image with shape (num_gts, ). -- preds_dicts: -- all_cls_scores (Tensor): Classification score of all -- decoder layers, has shape -- [nb_dec, bs, num_query, cls_out_channels]. 
-- all_bbox_preds (Tensor): Sigmoid regression -- outputs of all decode layers. Each is a 4D-tensor with -- normalized coordinate format (cx, cy, w, h) and shape -- [nb_dec, bs, num_query, 4]. -- enc_cls_scores (Tensor): Classification scores of -- points on encode feature map , has shape -- (N, h*w, num_classes). Only be passed when as_two_stage is -- True, otherwise is None. -- enc_bbox_preds (Tensor): Regression results of each points -- on the encode feature map, has shape (N, h*w, 4). Only be -- passed when as_two_stage is True, otherwise is None. -- gt_bboxes_ignore (list[Tensor], optional): Bounding boxes -- which can be ignored for each image. Default None. -- Returns: -- dict[str, Tensor]: A dictionary of loss components. -+ def voxelize(self, points, pc_range, voxel_size, pts_semantic_mask=None): - """ -- assert gt_bboxes_ignore is None, \ -- f'{self.__class__.__name__} only supports ' \ -- f'for gt_bboxes_ignore setting to None.' -- -- # GT voxel supervision -- pts = pts_sem[:,:3] -- pts_semantic_mask = pts_sem[:,3:4] -- -- pts_numpy = pts.cpu().numpy() -- pts_semantic_mask_numpy = pts_semantic_mask.cpu().numpy() -- points_grid_ind = np.floor((np.clip(pts_numpy, self.pc_range[:3],self.pc_range[3:]) - self.pc_range[:3]) / self.voxel_lidar).astype(np.int) -- label_voxel_pair = np.concatenate([points_grid_ind, pts_semantic_mask_numpy], axis=1) -- label_voxel_pair = label_voxel_pair[np.lexsort((points_grid_ind[:, 0], points_grid_ind[:, 1], points_grid_ind[:, 2])), :] -- label_voxel = torch.tensor(label_voxel_pair).to(pts.device).long() -- if dense_occupancy is None: -- pts_coors,voxelized_data,voxel_coors = self.voxelize(pts,self.pc_range,self.voxel_lidar) -- voxel_label = self.label_voxelization(pts_semantic_mask, pts_coors, voxel_coors) -- -- pts_coors_det,voxelized_data_det,voxel_coors_det = self.voxelize(pts,self.pc_range,self.voxel_det) -- voxel_label_det = self.label_voxelization(pts_semantic_mask, pts_coors_det, voxel_coors_det) -- -- all_cls_scores = preds_dicts['all_cls_scores'] -- all_bbox_preds = preds_dicts['all_bbox_preds'] -- enc_cls_scores = preds_dicts['enc_cls_scores'] -- enc_bbox_preds = preds_dicts['enc_bbox_preds'] -- occupancy = preds_dicts['occupancy'] -- occupancy_det = preds_dicts['occupancy_det'] -- -- occupancy_pred = occupancy.squeeze(0) -- occupancy_det_pred = occupancy_det.squeeze(0) -- -- cls_num,occ_z,occ_h,occ_w = occupancy_pred.shape -- if dense_occupancy is None: -- occupancy_label = (torch.ones(1,occ_z,occ_h,occ_w)*cls_num).to(occupancy_pred.device).long() -- else: -- occupancy_label = (torch.zeros(1,occ_z,occ_h,occ_w)).to(occupancy_pred.device).long() -- occupancy_det_label = (torch.ones(1,self.bev_z,self.bev_h,self.bev_w)*2).to(occupancy_det_pred.device).long() -- -- # Matrix operation acceleration -- if dense_occupancy is None: -- occupancy_label[0,label_voxel[:,2],label_voxel[:,1],label_voxel[:,0]] = label_voxel[:,3] -- else: -- dense_occupancy = dense_occupancy.long().squeeze(0) -- occupancy_label[0,dense_occupancy[:,0],dense_occupancy[:,1],dense_occupancy[:,2]]=dense_occupancy[:,3] -- -- voxel_coors_det[:,1] = voxel_coors_det[:,1].clip(min=0,max=self.bev_z-1) -- voxel_coors_det[:,2] = voxel_coors_det[:,2].clip(min=0,max=self.bev_h-1) -- voxel_coors_det[:,3] = voxel_coors_det[:,3].clip(min=0,max=self.bev_w-1) -- -- det_label_binary = ((voxel_label_det>=1)&(voxel_label_det<=10)) -- det_label = det_label_binary.long() -- occupancy_det_label[0,voxel_coors_det[:,1],voxel_coors_det[:,2],voxel_coors_det[:,3]]=det_label -- -- losses_seg_aux = 
self.lidar_seg_aux_loss(occupancy_pred.unsqueeze(0),occupancy_label) -- -- occupancy_det_label = occupancy_det_label.reshape(-1) -- occupancy_label = occupancy_label.reshape(-1) -- -- -- assert occupancy_label.max()<=cls_num and occupancy_label.min()>=0 -- occupancy_pred = occupancy_pred.reshape(cls_num,-1).permute(1,0) -- occupancy_det_pred = occupancy_det_pred.reshape(2,-1).permute(1,0) -- -- num_dec_layers = len(all_cls_scores) -- device = gt_labels_list[0].device -- -- gt_bboxes_list = [torch.cat((gt_bboxes.gravity_center, gt_bboxes.tensor[:, 3:]), -- dim=1).to(device) for gt_bboxes in gt_bboxes_list] -- -- all_gt_bboxes_list = [gt_bboxes_list for _ in range(num_dec_layers)] -- all_gt_labels_list = [gt_labels_list for _ in range(num_dec_layers)] -- all_gt_bboxes_ignore_list = [ -- gt_bboxes_ignore for _ in range(num_dec_layers) -- ] -- -- losses_cls, losses_bbox = multi_apply( -- self.loss_single, all_cls_scores, all_bbox_preds, -- all_gt_bboxes_list, all_gt_labels_list, -- all_gt_bboxes_ignore_list) -- -- loss_dict = dict() -- -- # Lidar seg loss -- if dense_occupancy is None: -- num_total_pos = len(voxel_label) -- else: -- num_total_pos = len(dense_occupancy) -- num_total_neg = len(occupancy_label)-num_total_pos -- avg_factor = num_total_pos * 1.0 + num_total_neg * self.bg_weight -- if self.sync_cls_avg_factor: -- avg_factor = reduce_mean( -- occupancy_pred.new_tensor([avg_factor])) -- avg_factor = max(avg_factor, 1) -- -- losses_seg = self.lidar_seg_loss(occupancy_pred,occupancy_label,avg_factor=avg_factor) -- -- # Lidar det loss -- num_total_pos_det = len(voxel_label_det) -- -- -- num_total_neg_det = len(occupancy_det_label)-num_total_pos_det -- avg_factor_det = num_total_pos_det * 1.0 + num_total_neg_det * self.bg_weight -- if self.sync_cls_avg_factor: -- avg_factor_det = reduce_mean( -- occupancy_det_pred.new_tensor([avg_factor_det])) -- avg_factor_det = max(avg_factor_det, 1) -+ Input: -+ points [N, 3], (x, y, z) -+ point_cloud_range [6], [-51.2, -51.2, -5.0, 51.2, 51.2, 3.0], (-x, -y, -z, x, y, z) -+ voxelization_size [3], e.g. [0.256, 0.256, 0.125] - -- losses_det = self.lidar_det_loss(occupancy_det_pred,occupancy_det_label,avg_factor=avg_factor_det) -+ Output: -+ coors [N,4], (0, z, y, x) -+ unq_coors [M,4], (0, z, y, x) - -- # loss of proposal generated from encode feature map. 
-- if enc_cls_scores is not None: -- binary_labels_list = [ -- torch.zeros_like(gt_labels_list[i]) -- for i in range(len(all_gt_labels_list)) -- ] -- enc_loss_cls, enc_losses_bbox = \ -- self.loss_single(enc_cls_scores, enc_bbox_preds, -- gt_bboxes_list, binary_labels_list, gt_bboxes_ignore) -- loss_dict['enc_loss_cls'] = enc_loss_cls -- loss_dict['enc_loss_bbox'] = enc_losses_bbox -+ """ - -- # loss from the last decoder layer -- loss_dict['loss_cls'] = losses_cls[-1] -- loss_dict['loss_bbox'] = losses_bbox[-1] -- loss_dict['loss_seg'] = losses_seg -- loss_dict['loss_det'] = losses_det -- loss_dict['loss_seg_aux'] = losses_seg_aux -+ coors = torch.div(points[:, :3] - pc_range[None, :3], voxel_size[None, :], rounding_mode='floor').to(torch.int32) - -- # loss from other decoder layers -- num_dec_layer = 0 -- for loss_cls_i, loss_bbox_i in zip(losses_cls[:-1], -- losses_bbox[:-1]): -- loss_dict[f'd{num_dec_layer}.loss_cls'] = loss_cls_i -- loss_dict[f'd{num_dec_layer}.loss_bbox'] = loss_bbox_i -- num_dec_layer += 1 -+ unq_coors, unq_inv = custom_unique_n3(coors, return_inverse=True, return_counts=False, dim=0) - -- return loss_dict -+ if pts_semantic_mask is not None: -+ with torch.no_grad(): -+ voxel_label_my, _ = mx_driving.scatter_max(pts_semantic_mask, unq_inv.to(torch.int32)) -+ return coors[:, [2, 1, 0]].long(), unq_coors.long(), voxel_label_my.squeeze(-1).long() -+ return coors[:, [2, 1, 0]].long(), unq_coors.long() - - @force_fp32(apply_to=('preds_dicts')) - def get_bboxes(self, preds_dicts, img_metas, rescale=False): -@@ -810,190 +679,58 @@ class PanoSegOccHead(DETRHead): - - return ret_list - -- def decode_lidar_seg(self,points,occupancy): -+ def decode_lidar_seg(self, points, occupancy): - -- pts_coors,voxelized_data,voxel_coors = self.voxelize(points,self.pc_range,self.voxel_lidar) -+ pts_coors, _ = self.voxelize(points, self.pc_range, self.voxel_lidar) - - # clip out-ranged points -- z_max = int((self.pc_range[5]-self.pc_range[2])/self.voxel_lidar[2])-1 -- y_max = int((self.pc_range[4]-self.pc_range[1])/self.voxel_lidar[1])-1 -- x_max = int((self.pc_range[3]-self.pc_range[0])/self.voxel_lidar[0])-1 -- -- # valid_mask = (pts_coors[:,1].cpu().numpy()>=0) & (pts_coors[:,1].cpu().numpy()<=z_max) \ -- # & (pts_coors[:,2].cpu().numpy()>=0) & (pts_coors[:,2].cpu().numpy()<=y_max) \ -- # & (pts_coors[:,3].cpu().numpy()>=0) & (pts_coors[:,3].cpu().numpy()<=x_max) -+ z_max = int((self.pc_range[5] - self.pc_range[2]) / self.voxel_lidar[2]) - 1 -+ y_max = int((self.pc_range[4] - self.pc_range[1]) / self.voxel_lidar[1]) - 1 -+ x_max = int((self.pc_range[3] - self.pc_range[0]) / self.voxel_lidar[0]) - 1 - -- pts_coors[:,1] = pts_coors[:,1].clip(min=0,max=z_max) -- pts_coors[:,2] = pts_coors[:,2].clip(min=0,max=y_max) -- pts_coors[:,3] = pts_coors[:,3].clip(min=0,max=x_max) -- -- pts_pred = occupancy[:,:,pts_coors[:,1],pts_coors[:,2],pts_coors[:,3]].squeeze(0).softmax(dim=0).argmax(dim=0).cpu().numpy() -+ pts_coors[:, 0] = pts_coors[:, 0].clip(min=0, max=z_max) -+ pts_coors[:, 1] = pts_coors[:, 1].clip(min=0, max=y_max) -+ pts_coors[:, 2] = pts_coors[:, 2].clip(min=0, max=x_max) - -- # pts_pred[valid_mask==False]=15 -+ pts_pred = occupancy[:, :, pts_coors[:, 0], pts_coors[:, 1], pts_coors[:, 2]].squeeze(0).softmax(dim=0).argmax(dim=0).cpu().numpy() - - return pts_pred - -- def voxelize(self, points,point_cloud_range,voxelization_size): -- """ -- Input: -- points -- -- Output: -- coors [N,4] -- voxelized_data [M,3] -- voxel_coors [M,4] -- -- """ -- voxel_size = torch.tensor(voxelization_size, 
device=points.device) -- pc_range = torch.tensor(point_cloud_range, device=points.device) -- coors = torch.div(points[:, :3] - pc_range[None, :3], voxel_size[None, :], rounding_mode='floor').long() -- coors = coors[:, [2, 1, 0]] # to zyx order -- -- new_coors, unq_inv = torch.unique(coors, return_inverse=True, return_counts=False, dim=0) -- -- voxelized_data, voxel_coors = scatter_v2(points, coors, mode='avg', return_inv=False, new_coors=new_coors, unq_inv=unq_inv) -- -- batch_idx_pts = torch.zeros(coors.size(0),1).to(device=points.device) -- batch_idx_vox = torch.zeros(voxel_coors.size(0),1).to(device=points.device) -- -- coors_batch = torch.cat([batch_idx_pts,coors],dim=1) -- voxel_coors_batch = torch.cat([batch_idx_vox,voxel_coors],dim=1) -- -- return coors_batch.long(),voxelized_data,voxel_coors_batch.long() -- - def decode_lidar_seg_hr(self,points,occupancy): - - out_h = 512 - out_w = 512 - out_z = 160 - -- self.voxel_lidar = [102.4/out_h,102.4/out_w,8/out_z] -+ self.voxel_lidar = [102.4/out_h, 102.4/out_w, 8/out_z] - -- pts_coors,voxelized_data,voxel_coors = self.voxelize(points,self.pc_range,self.voxel_lidar) -+ pts_coors, _ = self.voxelize(points, self.pc_range, self.voxel_lidar) - - # clip out-ranged points -- z_max = int((self.pc_range[5]-self.pc_range[2])/self.voxel_lidar[2])-1 -- y_max = int((self.pc_range[4]-self.pc_range[1])/self.voxel_lidar[1])-1 -- x_max = int((self.pc_range[3]-self.pc_range[0])/self.voxel_lidar[0])-1 -- pts_coors[:,1] = pts_coors[:,1].clip(min=0,max=z_max) -- pts_coors[:,2] = pts_coors[:,2].clip(min=0,max=y_max) -- pts_coors[:,3] = pts_coors[:,3].clip(min=0,max=x_max) -+ z_max = int((self.pc_range[5] - self.pc_range[2]) / self.voxel_lidar[2]) - 1 -+ y_max = int((self.pc_range[4] - self.pc_range[1]) / self.voxel_lidar[1]) - 1 -+ x_max = int((self.pc_range[3] - self.pc_range[0]) / self.voxel_lidar[0]) - 1 -+ pts_coors[:, 0] = pts_coors[:, 0].clip(min=0, max=z_max) -+ pts_coors[:, 1] = pts_coors[:, 1].clip(min=0, max=y_max) -+ pts_coors[:, 2] = pts_coors[:, 2].clip(min=0, max=x_max) - - -- new_h = torch.linspace(-1, 1, out_h).view(1,out_h,1).expand(out_z,out_h,out_w) -- new_w = torch.linspace(-1, 1, out_w).view(1,1,out_w).expand(out_z,out_h,out_w) -- new_z = torch.linspace(-1, 1, out_z).view(out_z,1,1).expand(out_z,out_h,out_w) -+ new_h = torch.linspace(-1, 1, out_h).view(1, out_h, 1).expand(out_z, out_h, out_w) -+ new_w = torch.linspace(-1, 1, out_w).view(1, 1, out_w).expand(out_z, out_h, out_w) -+ new_z = torch.linspace(-1, 1, out_z).view(out_z, 1, 1).expand(out_z, out_h, out_w) - -- grid = torch.cat((new_w.unsqueeze(3),new_h.unsqueeze(3), new_z.unsqueeze(3)), dim=-1) -+ grid = torch.cat((new_w.unsqueeze(3), new_h.unsqueeze(3), new_z.unsqueeze(3)), dim=-1) - - grid = grid.unsqueeze(0).to(occupancy.device) - -+ torch.npu.set_compile_mode(jit_compile=True) - out_logit = F.grid_sample(occupancy, grid=grid) -+ torch.npu.set_compile_mode(jit_compile=False) - -- pts_pred = out_logit[:,:,pts_coors[:,1],pts_coors[:,2],pts_coors[:,3]].squeeze(0).softmax(dim=0).argmax(dim=0).cpu().numpy() -+ pts_pred = out_logit[:, :, pts_coors[:, 0], pts_coors[:, 1], pts_coors[:, 2]].squeeze(0).softmax(dim=0).argmax(dim=0).cpu().numpy() - return pts_pred - -- def decode_occupancy(self,points,occupancy): -- out_h = 400 -- out_w = 400 -- out_z = 64 -- self.voxel_lidar = [102.4/out_h,102.4/out_w,8/out_z] -- -- pts_coors,voxelized_data,voxel_coors = self.voxelize(points,self.pc_range,self.voxel_lidar) -- -- -- # clip out-ranged points -- z_max = 
int((self.pc_range[5]-self.pc_range[2])/self.voxel_lidar[2])-1 -- y_max = int((self.pc_range[4]-self.pc_range[1])/self.voxel_lidar[1])-1 -- x_max = int((self.pc_range[3]-self.pc_range[0])/self.voxel_lidar[0])-1 -- pts_coors[:,1] = pts_coors[:,1].clip(min=0,max=z_max) -- pts_coors[:,2] = pts_coors[:,2].clip(min=0,max=y_max) -- pts_coors[:,3] = pts_coors[:,3].clip(min=0,max=x_max) -- -- -- new_h = torch.linspace(-1, 1, out_h).view(1,out_h,1).expand(out_z,out_h,out_w) -- new_w = torch.linspace(-1, 1, out_w).view(1,1,out_w).expand(out_z,out_h,out_w) -- new_z = torch.linspace(-1, 1, out_z).view(out_z,1,1).expand(out_z,out_h,out_w) -- -- grid = torch.cat((new_w.unsqueeze(3),new_h.unsqueeze(3), new_z.unsqueeze(3)), dim=-1) -- -- grid = grid.unsqueeze(0).to(occupancy.device) -- -- out_logit = F.grid_sample(occupancy, grid=grid) -- -- # Occupancy Visualize -- out_class = out_logit.sigmoid()>0.2 -- all_index = out_class.sum(dim=1).nonzero() -- -- out_voxel = out_logit[:,:,all_index[:,1],all_index[:,2],all_index[:,3]] -- out_voxel_scores = out_voxel.sigmoid() -- out_voxel_confidence,out_voxel_labels = out_voxel_scores.max(dim=1) -- output_occupancy = torch.cat((all_index.unsqueeze(0),out_voxel_labels.unsqueeze(-1)),dim=-1).cpu().numpy()[...,1:] -- -- return output_occupancy -- - def decode_lidar_seg_dense(self, dense, occupancy): - dense = dense.long() -- pts_pred = occupancy[:,:,dense[0,:,0],dense[0,:,1],dense[0,:,2]].squeeze(0).softmax(dim=0).argmax(dim=0).cpu().numpy() -+ pts_pred = occupancy[:, :, dense[0, :, 0], dense[0, :, 1], dense[0, :, 2]].squeeze(0).softmax(dim=0).argmax(dim=0).cpu().numpy() - return pts_pred -- -- @torch.no_grad() -- def label_voxelization(self, pts_semantic_mask, pts_coors, voxel_coors): -- mask = pts_semantic_mask -- assert mask.size(0) == pts_coors.size(0) -- -- pts_coors_cls = torch.cat([pts_coors, mask], dim=1) #[N, 5] -- unq_coors_cls, unq_inv, unq_cnt = torch.unique(pts_coors_cls, return_inverse=True, return_counts=True, dim=0) #[N1, 5], [N], [N1] -- -- unq_coors, unq_inv_2, _ = torch.unique(unq_coors_cls[:, :4], return_inverse=True, return_counts=True, dim=0) #[N2, 4], [N1], [N2,] -- max_num, max_inds = torch_scatter.scatter_max(unq_cnt.float()[:,None], unq_inv_2, dim=0) #[N2, 1], [N2, 1] -- -- cls_of_max_num = unq_coors_cls[:, -1][max_inds.reshape(-1)] #[N2,] -- cls_of_max_num_N1 = cls_of_max_num[unq_inv_2] #[N1] -- cls_of_max_num_at_pts = cls_of_max_num_N1[unq_inv] #[N] -- -- assert cls_of_max_num_at_pts.size(0) == mask.size(0) -- -- cls_no_change = cls_of_max_num_at_pts == mask[:,0] # fix memory bug when scale up -- # cls_no_change = cls_of_max_num_at_pts == mask -- assert cls_no_change.any() -- -- max_pts_coors = pts_coors.max(0)[0] -- max_voxel_coors = voxel_coors.max(0)[0] -- assert (max_voxel_coors <= max_pts_coors).all() -- bsz, num_win_z, num_win_y, num_win_x = \ -- int(max_pts_coors[0].item() + 1), int(max_pts_coors[1].item() + 1), int(max_pts_coors[2].item() + 1), int(max_pts_coors[3].item() + 1) -- -- canvas = -pts_coors.new_ones((bsz, num_win_z, num_win_y, num_win_x)) -- -- canvas[pts_coors[:, 0], pts_coors[:, 1], pts_coors[:, 2], pts_coors[:, 3]] = \ -- torch.arange(pts_coors.size(0), dtype=pts_coors.dtype, device=pts_coors.device) -- -- fetch_inds_of_points = canvas[voxel_coors[:, 0], voxel_coors[:, 1], voxel_coors[:, 2], voxel_coors[:, 3]] -- -- assert (fetch_inds_of_points >= 0).all(), '-1 should not be in it.' 
-- -- voxel_label = cls_of_max_num_at_pts[fetch_inds_of_points] -- -- voxel_label = torch.clamp(voxel_label,min=0).long() -- -- return voxel_label -- -- @torch.no_grad() -- def get_point_pred(self,occupancy,pts_coors,voxel_coors,voxel_label,pts_semantic_mask): -- -- voxel_pred = occupancy[:,:,voxel_coors[:,1],voxel_coors[:,2],voxel_coors[:,3]].squeeze(0).softmax(dim=0).argmax(dim=0).cpu() -- -- voxel_gt = voxel_label.long().cpu() -- -- accurate = voxel_pred==voxel_gt -- -- acc = accurate.sum()/len(voxel_gt) -- -- pts_pred = occupancy[:,:,pts_coors[:,1],pts_coors[:,2],pts_coors[:,3]].squeeze(0).softmax(dim=0).argmax(dim=0).cpu() -- pts_gt = pts_semantic_mask.long().squeeze(1).cpu() -- -- pts_accurate = pts_pred==pts_gt -- pts_acc = pts_accurate.sum()/len(pts_gt) -- -- return pts_acc -diff --git a/projects/mmdet3d_plugin/bevformer/detectors/__init__.py b/projects/mmdet3d_plugin/bevformer/detectors/__init__.py -index 1012ef3..bf7f763 100644 ---- a/projects/mmdet3d_plugin/bevformer/detectors/__init__.py -+++ b/projects/mmdet3d_plugin/bevformer/detectors/__init__.py -@@ -1,3 +1,2 @@ - from .pano_occ import PanoOcc --from .panoseg_occ import PanoSegOcc --from .panoseg_occ_sparse import PanoSegOccSparse -\ No newline at end of file -+from .panoseg_occ import PanoSegOcc -\ No newline at end of file -diff --git a/projects/mmdet3d_plugin/bevformer/detectors/pano_occ.py b/projects/mmdet3d_plugin/bevformer/detectors/pano_occ.py -index 46a8b99..4d8cc21 100644 ---- a/projects/mmdet3d_plugin/bevformer/detectors/pano_occ.py -+++ b/projects/mmdet3d_plugin/bevformer/detectors/pano_occ.py -@@ -128,10 +128,6 @@ class PanoOcc(MVXTwoStageDetector): - losses = self.pts_bbox_head.loss(*loss_inputs, img_metas=img_metas) - return losses - -- def forward_dummy(self, img): -- dummy_metas = None -- return self.forward_test(img=img, img_metas=[[dummy_metas]]) -- - def forward(self, return_loss=True, **kwargs): - """Calls either forward_train or forward_test depending on whether - return_loss=True. 
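# Illustrative sketch, not part of the original patch: the patched PanoSegOccHead.voxelize()
# above replaces torch.unique / torch_scatter with the mx_driving kernels (custom_unique_n3,
# mx_driving.scatter_max). A rough pure-PyTorch reference of the same step is sketched here
# for readers without mx_driving; the helper name voxelize_reference is hypothetical and the
# axis ordering may differ from the NPU kernels in minor details.
import torch


def voxelize_reference(points, pc_range, voxel_size, pts_semantic_mask=None):
    pc_range = torch.as_tensor(pc_range, device=points.device, dtype=points.dtype)
    voxel_size = torch.as_tensor(voxel_size, device=points.device, dtype=points.dtype)
    # integer voxel coordinates in (x, y, z) order, as in the patched voxelize()
    coors = torch.div(points[:, :3] - pc_range[None, :3],
                      voxel_size[None, :], rounding_mode='floor').long()
    # unique voxels plus the point-to-voxel inverse mapping
    unq_coors, unq_inv = torch.unique(coors, return_inverse=True, dim=0)
    if pts_semantic_mask is None:
        return coors[:, [2, 1, 0]], unq_coors
    # per-voxel label via a max reduction over point labels (stand-in for mx_driving.scatter_max)
    voxel_label = torch.zeros(unq_coors.size(0), dtype=pts_semantic_mask.dtype,
                              device=points.device)
    voxel_label.scatter_reduce_(0, unq_inv, pts_semantic_mask.reshape(-1),
                                reduce='amax', include_self=False)
    return coors[:, [2, 1, 0]], unq_coors, voxel_label.long()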
-diff --git a/projects/mmdet3d_plugin/bevformer/modules/__init__.py b/projects/mmdet3d_plugin/bevformer/modules/__init__.py -index 17ded68..f880296 100644 ---- a/projects/mmdet3d_plugin/bevformer/modules/__init__.py -+++ b/projects/mmdet3d_plugin/bevformer/modules/__init__.py -@@ -7,12 +7,10 @@ from .decoder import DetectionTransformerDecoder - from .occ_temporal_attention import OccTemporalAttention - from .occ_spatial_attention import OccSpatialAttention - from .occ_decoder import OccupancyDecoder --from .occ_mlp_decoder import MLP_Decoder, SparseMLPDecoder -+from .occ_mlp_decoder import MLP_Decoder - from .occ_temporal_encoder import OccTemporalEncoder - from .transformer_occ import TransformerOcc - from .occ_voxel_decoder import VoxelDecoder - from .pano_transformer_occ import PanoOccTransformer - from .panoseg_transformer_occ import PanoSegOccTransformer --from .occ_voxel_seg_decoder import VoxelNaiveDecoder --from .sparse_occ_decoder import SparseOccupancyDecoder --from .sparse_occ_transformer import SparseOccupancyTransformer -\ No newline at end of file -+from .occ_voxel_seg_decoder import VoxelNaiveDecoder -\ No newline at end of file -diff --git a/projects/mmdet3d_plugin/bevformer/modules/decoder.py b/projects/mmdet3d_plugin/bevformer/modules/decoder.py -index 33024f8..21f8b51 100644 ---- a/projects/mmdet3d_plugin/bevformer/modules/decoder.py -+++ b/projects/mmdet3d_plugin/bevformer/modules/decoder.py -@@ -23,12 +23,7 @@ from mmcv.runner.base_module import BaseModule, ModuleList, Sequential - from mmcv.utils import (ConfigDict, build_from_cfg, deprecated_api_warning, - to_2tuple) - --from mmcv.utils import ext_loader --from .multi_scale_deformable_attn_function import MultiScaleDeformableAttnFunction_fp32, \ -- MultiScaleDeformableAttnFunction_fp16 -- --ext_module = ext_loader.load_ext( -- '_ext', ['ms_deform_attn_backward', 'ms_deform_attn_forward']) -+from mx_driving import multi_scale_deformable_attn - - - def inverse_sigmoid(x, eps=1e-5): -@@ -323,15 +318,8 @@ class CustomMSDeformableAttention(BaseModule): - f'Last dim of reference_points must be' - f' 2 or 4, but get {reference_points.shape[-1]} instead.') - if torch.cuda.is_available() and value.is_cuda: -- -- # using fp16 deformable attention is unstable because it performs many sum operations -- if value.dtype == torch.float16: -- MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32 -- else: -- MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32 -- output = MultiScaleDeformableAttnFunction.apply( -- value, spatial_shapes, level_start_index, sampling_locations, -- attention_weights, self.im2col_step) -+ output = multi_scale_deformable_attn(value, spatial_shapes, level_start_index, -+ sampling_locations, attention_weights) - else: - output = multi_scale_deformable_attn_pytorch( - value, spatial_shapes, sampling_locations, attention_weights) -diff --git a/projects/mmdet3d_plugin/bevformer/modules/occ_decoder.py b/projects/mmdet3d_plugin/bevformer/modules/occ_decoder.py -index 15058e4..e4caa64 100644 ---- a/projects/mmdet3d_plugin/bevformer/modules/occ_decoder.py -+++ b/projects/mmdet3d_plugin/bevformer/modules/occ_decoder.py -@@ -4,6 +4,23 @@ from mmcv.cnn.bricks.registry import TRANSFORMER_LAYER_SEQUENCE - import torch.nn.functional as F - - -+def interpolate_trilinear(x, scale_factor, mode, align_corners): -+ # assert mode == 'trilinear' -+ # assert align_corners == False -+ # bilinear + bilinear -+ scale_t, scale_h, scale_w = scale_factor -+ N, C, T, H, W = x.size(0), x.size(1), 
x.size(2), x.size(3), x.size(4) -+ -+ x_fused_nc = x.reshape(N*C, T, H, W) -+ y_resize_hw = F.interpolate(x_fused_nc, scale_factor=(scale_h, scale_w), mode='bilinear') -+ new_shape_h, new_shape_w = y_resize_hw.shape[-2], y_resize_hw.shape[-1] -+ y_fused_hw = y_resize_hw.reshape(N, C, T, new_shape_h*new_shape_w) -+ y_resize_t = F.interpolate(y_fused_hw, scale_factor=(scale_t, 1), mode='bilinear') -+ new_shape_t = y_resize_t.shape[-2] -+ y = y_resize_t.reshape(N, C, new_shape_t, new_shape_h, new_shape_w) -+ return y -+ -+ - @TRANSFORMER_LAYER_SEQUENCE.register_module() - class OccupancyDecoder(BaseModule): - -@@ -66,6 +83,8 @@ class OccupancyDecoder(BaseModule): - - voxel_cls = self.semantic_cls(voxel_up1) - -- voxel_pred = F.interpolate(voxel_cls,scale_factor=(self.inter_up_rate[0],self.inter_up_rate[1],self.inter_up_rate[2]),mode=self.upsampling_method,align_corners=self.align_corners) -+ voxel_pred = interpolate_trilinear(voxel_cls, -+ scale_factor=(self.inter_up_rate[0], self.inter_up_rate[1], self.inter_up_rate[2]), -+ mode=self.upsampling_method, align_corners=self.align_corners) - - return voxel_pred, voxel_det -\ No newline at end of file -diff --git a/projects/mmdet3d_plugin/bevformer/modules/occ_mlp_decoder.py b/projects/mmdet3d_plugin/bevformer/modules/occ_mlp_decoder.py -index 615e26b..952fdce 100644 ---- a/projects/mmdet3d_plugin/bevformer/modules/occ_mlp_decoder.py -+++ b/projects/mmdet3d_plugin/bevformer/modules/occ_mlp_decoder.py -@@ -4,6 +4,24 @@ from mmcv.cnn.bricks.registry import TRANSFORMER_LAYER_SEQUENCE - import torch.nn.functional as F - import torch - -+ -+def interpolate_trilinear(x, scale_factor, mode, align_corners): -+ # assert mode == 'trilinear' -+ # assert align_corners == False -+ # bilinear + bilinear -+ scale_t, scale_h, scale_w = scale_factor -+ N, C, T, H, W = x.size(0), x.size(1), x.size(2), x.size(3), x.size(4) -+ -+ x_fused_nc = x.reshape(N*C, T, H, W) -+ y_resize_hw = F.interpolate(x_fused_nc, scale_factor=(scale_h, scale_w), mode='bilinear') -+ new_shape_h, new_shape_w = y_resize_hw.shape[-2], y_resize_hw.shape[-1] -+ y_fused_hw = y_resize_hw.reshape(N, C, T, new_shape_h*new_shape_w) -+ y_resize_t = F.interpolate(y_fused_hw, scale_factor=(scale_t, 1), mode='bilinear') -+ new_shape_t = y_resize_t.shape[-2] -+ y = y_resize_t.reshape(N, C, new_shape_t, new_shape_h, new_shape_w) -+ return y -+ -+ - @TRANSFORMER_LAYER_SEQUENCE.register_module() - class MLP_Decoder(BaseModule): - -@@ -32,7 +50,9 @@ class MLP_Decoder(BaseModule): - - voxel_point_cls = point_cls.view(1,inputs.shape[2],inputs.shape[3],inputs.shape[4],-1).permute(0,4,1,2,3) - -- voxel_logits = F.interpolate(voxel_point_cls,scale_factor=(self.inter_up_rate[0],self.inter_up_rate[1],self.inter_up_rate[2]),mode=self.upsampling_method,align_corners=self.align_corners) -+ voxel_logits = interpolate_trilinear(voxel_point_cls, -+ scale_factor=(self.inter_up_rate[0], self.inter_up_rate[1], self.inter_up_rate[2]), -+ mode=self.upsampling_method, align_corners=self.align_corners) - - return voxel_logits - -diff --git a/projects/mmdet3d_plugin/bevformer/modules/occ_temporal_attention.py b/projects/mmdet3d_plugin/bevformer/modules/occ_temporal_attention.py -index 8f62f3a..8236e93 100644 ---- a/projects/mmdet3d_plugin/bevformer/modules/occ_temporal_attention.py -+++ b/projects/mmdet3d_plugin/bevformer/modules/occ_temporal_attention.py -@@ -7,6 +7,7 @@ - from projects.mmdet3d_plugin.models.utils.bricks import run_time - from .multi_scale_deformable_attn_function import MultiScaleDeformableAttnFunction_fp32 - 
from mmcv.ops.multi_scale_deform_attn import multi_scale_deformable_attn_pytorch -+from mx_driving import multi_scale_deformable_attn - import warnings - import torch - import torch.nn as nn -@@ -243,15 +244,8 @@ class OccTemporalAttention(BaseModule): - - sampling_locations = sampling_locations.contiguous() - if torch.cuda.is_available() and value.is_cuda: -- -- # using fp16 deformable attention is unstable because it performs many sum operations -- if value.dtype == torch.float16: -- MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32 -- else: -- MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32 -- output = MultiScaleDeformableAttnFunction.apply( -- value, spatial_shapes, level_start_index, sampling_locations, -- attention_weights, self.im2col_step) -+ output = multi_scale_deformable_attn(value, spatial_shapes, level_start_index, -+ sampling_locations, attention_weights) - else: - - output = multi_scale_deformable_attn_pytorch( -diff --git a/projects/mmdet3d_plugin/bevformer/modules/panoseg_transformer_occ.py b/projects/mmdet3d_plugin/bevformer/modules/panoseg_transformer_occ.py -index be6c6ed..5604546 100644 ---- a/projects/mmdet3d_plugin/bevformer/modules/panoseg_transformer_occ.py -+++ b/projects/mmdet3d_plugin/bevformer/modules/panoseg_transformer_occ.py -@@ -206,10 +206,12 @@ class PanoSegOccTransformer(BaseModule): - - curr_grid_in_prev_frame = torch.stack(curr_grid_in_prev_frame_lst, dim=0) - -+ torch.npu.set_compile_mode(jit_compile=True) - prev_bev_warp_to_curr_frame = nn.functional.grid_sample( - prev_bev[i].permute(0, 1, 4, 2, 3), # [bs, dim, z, h, w] - curr_grid_in_prev_frame.permute(0, 3, 1, 2, 4), # [bs, z, h, w, 3] - align_corners=False) -+ torch.npu.set_compile_mode(jit_compile=False) - prev_bev = prev_bev_warp_to_curr_frame.permute(0, 1, 3, 4, 2).unsqueeze(0) # add bs dim, [bs, dim, h, w, z] - - return prev_bev -diff --git a/projects/mmdet3d_plugin/bevformer/modules/spatial_cross_attention.py b/projects/mmdet3d_plugin/bevformer/modules/spatial_cross_attention.py -index b53b66c..dfb7e1f 100644 ---- a/projects/mmdet3d_plugin/bevformer/modules/spatial_cross_attention.py -+++ b/projects/mmdet3d_plugin/bevformer/modules/spatial_cross_attention.py -@@ -5,27 +5,27 @@ - # Modified by Zhiqi Li - # --------------------------------------------- - --from mmcv.ops.multi_scale_deform_attn import multi_scale_deformable_attn_pytorch -+import math - import warnings -+ - import torch - import torch.nn as nn - import torch.nn.functional as F --from mmcv.cnn import xavier_init, constant_init --from mmcv.cnn.bricks.registry import (ATTENTION, -- TRANSFORMER_LAYER, -- TRANSFORMER_LAYER_SEQUENCE) -+from mmcv.cnn import constant_init, xavier_init -+from mmcv.cnn.bricks.registry import ATTENTION, TRANSFORMER_LAYER, TRANSFORMER_LAYER_SEQUENCE - from mmcv.cnn.bricks.transformer import build_attention --import math --from mmcv.runner import force_fp32, auto_fp16 -- -+from mmcv.ops.multi_scale_deform_attn import multi_scale_deformable_attn_pytorch -+from mmcv.runner import force_fp32 - from mmcv.runner.base_module import BaseModule, ModuleList, Sequential -- --from mmcv.utils import ext_loader --from .multi_scale_deformable_attn_function import MultiScaleDeformableAttnFunction_fp32, \ -- MultiScaleDeformableAttnFunction_fp16 - from projects.mmdet3d_plugin.models.utils.bricks import run_time --ext_module = ext_loader.load_ext( -- '_ext', ['ms_deform_attn_backward', 'ms_deform_attn_forward']) -+ -+from mx_driving import multi_scale_deformable_attn -+ -+ 
-+indexes_global = None -+max_len_global = None -+bev_mask_id_global = -1 -+count_global = None - - - @ATTENTION.register_module() -@@ -135,10 +135,27 @@ class SpatialCrossAttention(BaseModule): - # bevformer reference_points_cam shape: (num_cam,bs,h*w,num_points_in_pillar,2) - D = reference_points_cam.size(3) - indexes = [] -- for i, mask_per_img in enumerate(bev_mask): -- index_query_per_img = mask_per_img[0].sum(-1).nonzero().squeeze(-1) -- indexes.append(index_query_per_img) -- max_len = max([len(each) for each in indexes]) -+ global indexes_global, max_len_global, bev_mask_id_global, count_global -+ bev_mask_id = id(bev_mask) -+ if bev_mask_id == bev_mask_id_global: -+ indexes = indexes_global -+ max_len = max_len_global -+ count = count_global -+ else: -+ count = torch.any(bev_mask, 3) -+ bev_mask_ = count.squeeze() -+ for i, mask_per_img in enumerate(bev_mask_): -+ index_query_per_img = mask_per_img.nonzero().squeeze(-1) -+ indexes.append(index_query_per_img) -+ -+ max_len = max([len(each) for each in indexes]) -+ count = count.permute(1, 2, 0).sum(-1) -+ count = torch.clamp(count, min=1.0) -+ count = count[..., None] -+ count_global = count -+ indexes_global = indexes -+ max_len_global = max_len -+ bev_mask_id_global = bev_mask_id - - # each camera only interacts with its corresponding BEV queries. This step can greatly save GPU memory. - queries_rebatch = query.new_zeros( -@@ -146,9 +163,9 @@ class SpatialCrossAttention(BaseModule): - reference_points_rebatch = reference_points_cam.new_zeros( - [bs, self.num_cams, max_len, D, 2]) - -- for j in range(bs): -- for i, reference_points_per_img in enumerate(reference_points_cam): -- index_query_per_img = indexes[i] -+ for i, reference_points_per_img in enumerate(reference_points_cam): -+ index_query_per_img = indexes[i] -+ for j in range(bs): - queries_rebatch[j, i, :len(index_query_per_img)] = query[j, index_query_per_img] - reference_points_rebatch[j, i, :len(index_query_per_img)] = reference_points_per_img[j, index_query_per_img] - -@@ -159,17 +176,15 @@ class SpatialCrossAttention(BaseModule): - value = value.permute(2, 0, 1, 3).reshape( - bs * self.num_cams, l, self.embed_dims) - -- queries = self.deformable_attention(query=queries_rebatch.view(bs*self.num_cams, max_len, self.embed_dims), key=key, value=value, -- reference_points=reference_points_rebatch.view(bs*self.num_cams, max_len, D, 2), spatial_shapes=spatial_shapes, -+ queries = self.deformable_attention(query=queries_rebatch.view(bs * self.num_cams, max_len, self.embed_dims), key=key, value=value, -+ reference_points=reference_points_rebatch.view(bs * self.num_cams, max_len, D, 2), spatial_shapes=spatial_shapes, - level_start_index=level_start_index).view(bs, self.num_cams, max_len, self.embed_dims) - for j in range(bs): - for i, index_query_per_img in enumerate(indexes): - slots[j, index_query_per_img] += queries[j, i, :len(index_query_per_img)] - -- count = bev_mask.sum(-1) > 0 -- count = count.permute(1, 2, 0).sum(-1) -- count = torch.clamp(count, min=1.0) -- slots = slots / count[..., None] -+ -+ slots = slots / count - slots = self.output_proj(slots) - - return self.dropout(slots) + inp_residual -@@ -329,7 +344,7 @@ class MSDeformableAttention3D(BaseModule): - - bs, num_query, _ = query.shape - bs, num_value, _ = value.shape -- assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value -+ # assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value - - value = self.value_proj(value) - if key_padding_mask is not None: -@@ -366,7 +381,7 @@ class 
MSDeformableAttention3D(BaseModule): - bs, num_query, num_heads, num_levels, num_all_points // num_Z_anchors, num_Z_anchors, xy) - sampling_locations = reference_points + sampling_offsets - bs, num_query, num_heads, num_levels, num_points, num_Z_anchors, xy = sampling_locations.shape -- assert num_all_points == num_points * num_Z_anchors -+ # assert num_all_points == num_points * num_Z_anchors - - sampling_locations = sampling_locations.view( - bs, num_query, num_heads, num_levels, num_all_points, xy) -@@ -379,13 +394,8 @@ class MSDeformableAttention3D(BaseModule): - f' 2 or 4, but get {reference_points.shape[-1]} instead.') - - if torch.cuda.is_available() and value.is_cuda: -- if value.dtype == torch.float16: -- MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32 -- else: -- MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32 -- output = MultiScaleDeformableAttnFunction.apply( -- value, spatial_shapes, level_start_index, sampling_locations, -- attention_weights, self.im2col_step) -+ output = multi_scale_deformable_attn(value, spatial_shapes, level_start_index, -+ sampling_locations, attention_weights) - else: - output = multi_scale_deformable_attn_pytorch( - value, spatial_shapes, sampling_locations, attention_weights) -diff --git a/projects/mmdet3d_plugin/bevformer/modules/temporal_self_attention.py b/projects/mmdet3d_plugin/bevformer/modules/temporal_self_attention.py -index 78fb9f5..61873c9 100644 ---- a/projects/mmdet3d_plugin/bevformer/modules/temporal_self_attention.py -+++ b/projects/mmdet3d_plugin/bevformer/modules/temporal_self_attention.py -@@ -7,6 +7,7 @@ - from projects.mmdet3d_plugin.models.utils.bricks import run_time - from .multi_scale_deformable_attn_function import MultiScaleDeformableAttnFunction_fp32 - from mmcv.ops.multi_scale_deform_attn import multi_scale_deformable_attn_pytorch -+from mx_driving import multi_scale_deformable_attn - import warnings - import torch - import torch.nn as nn -@@ -238,15 +239,8 @@ class TemporalSelfAttention(BaseModule): - f'Last dim of reference_points must be' - f' 2 or 4, but get {reference_points.shape[-1]} instead.') - if torch.cuda.is_available() and value.is_cuda: -- -- # using fp16 deformable attention is unstable because it performs many sum operations -- if value.dtype == torch.float16: -- MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32 -- else: -- MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32 -- output = MultiScaleDeformableAttnFunction.apply( -- value, spatial_shapes, level_start_index, sampling_locations, -- attention_weights, self.im2col_step) -+ output = multi_scale_deformable_attn(value, spatial_shapes, level_start_index, -+ sampling_locations, attention_weights) - else: - - output = multi_scale_deformable_attn_pytorch( -diff --git a/projects/mmdet3d_plugin/datasets/builder.py b/projects/mmdet3d_plugin/datasets/builder.py -index 0ad7a92..0625933 100644 ---- a/projects/mmdet3d_plugin/datasets/builder.py -+++ b/projects/mmdet3d_plugin/datasets/builder.py -@@ -25,6 +25,7 @@ def build_dataloader(dataset, - seed=None, - shuffler_sampler=None, - nonshuffler_sampler=None, -+ pin_memory=False, - **kwargs): - """Build PyTorch DataLoader. - In distributed training, each GPU/process has a dataloader. 
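# Illustrative sketch, not part of the original patch: the recurring change in the attention
# modules above (decoder, occ_temporal_attention, spatial_cross_attention,
# temporal_self_attention) swaps the mmcv CUDA extension (MultiScaleDeformableAttnFunction.apply
# with im2col_step and fp16/fp32 branching) for the fused mx_driving kernel. The call signature
# is taken from the hunks above; the dispatch helper name and the device check are illustrative.
from mmcv.ops.multi_scale_deform_attn import multi_scale_deformable_attn_pytorch
from mx_driving import multi_scale_deformable_attn


def deform_attn_dispatch(value, spatial_shapes, level_start_index,
                         sampling_locations, attention_weights):
    if value.device.type != 'cpu':
        # device path: fused NPU kernel, no im2col_step or dtype branching needed
        return multi_scale_deformable_attn(value, spatial_shapes, level_start_index,
                                           sampling_locations, attention_weights)
    # CPU fallback, unchanged from the original modules
    return multi_scale_deformable_attn_pytorch(value, spatial_shapes,
                                               sampling_locations, attention_weights)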
-@@ -86,7 +87,7 @@ def build_dataloader(dataset, - sampler=sampler, - num_workers=num_workers, - collate_fn=partial(collate, samples_per_gpu=samples_per_gpu), -- pin_memory=False, -+ pin_memory=pin_memory, - worker_init_fn=init_fn, - **kwargs) - -diff --git a/projects/mmdet3d_plugin/datasets/pipelines/compose.py b/projects/mmdet3d_plugin/datasets/pipelines/compose.py -index 08e46a8..6052444 100644 ---- a/projects/mmdet3d_plugin/datasets/pipelines/compose.py -+++ b/projects/mmdet3d_plugin/datasets/pipelines/compose.py -@@ -3,6 +3,7 @@ import collections - from mmcv.utils import build_from_cfg - - from mmdet.datasets.builder import PIPELINES -+from mmdet3d.datasets.builder import PIPELINES as PIPELINES_3d - - @PIPELINES.register_module() - class CustomCompose: -@@ -16,7 +17,10 @@ class CustomCompose: - self.transforms = [] - for transform in transforms: - if isinstance(transform, dict): -- transform = build_from_cfg(transform, PIPELINES) -+ if transform["type"] not in PIPELINES: -+ transform = build_from_cfg(transform, PIPELINES_3d) -+ else: -+ transform = build_from_cfg(transform, PIPELINES) - self.transforms.append(transform) - elif callable(transform): - self.transforms.append(transform) -diff --git a/projects/mmdet3d_plugin/models/backbones/__init__.py b/projects/mmdet3d_plugin/models/backbones/__init__.py -index f86b114..cea72f5 100755 ---- a/projects/mmdet3d_plugin/models/backbones/__init__.py -+++ b/projects/mmdet3d_plugin/models/backbones/__init__.py -@@ -1,5 +1,3 @@ - from .vovnet import VoVNet --from .internv2_impl16 import InternV2Impl16 --from .sam_modeling import ImageEncoderViT - --__all__ = ['VoVNet', "InternV2Impl16", "ImageEncoderViT"] -\ No newline at end of file -+__all__ = ['VoVNet'] -\ No newline at end of file -diff --git a/projects/mmdet3d_plugin/models/backbones/sam_modeling/__init__.py b/projects/mmdet3d_plugin/models/backbones/sam_modeling/__init__.py -index 50f3bd6..e69de29 100644 ---- a/projects/mmdet3d_plugin/models/backbones/sam_modeling/__init__.py -+++ b/projects/mmdet3d_plugin/models/backbones/sam_modeling/__init__.py -@@ -1 +0,0 @@ --from .image_encoder import ImageEncoderViT -\ No newline at end of file -diff --git a/tools/dist_test.sh b/tools/dist_test.sh -index 3e2ec30..931aa0f 100755 ---- a/tools/dist_test.sh -+++ b/tools/dist_test.sh -@@ -6,5 +6,5 @@ GPUS=$3 - PORT=${PORT:-29503} - - PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ --python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \ -+torchrun --nproc_per_node=$GPUS --master_port=$PORT \ - $(dirname "$0")/test.py $CONFIG $CHECKPOINT --launcher pytorch ${@:4} --eval bbox -diff --git a/tools/dist_test_seg.sh b/tools/dist_test_seg.sh -index 7719313..0b4f6d2 100755 ---- a/tools/dist_test_seg.sh -+++ b/tools/dist_test_seg.sh -@@ -6,5 +6,5 @@ GPUS=$3 - PORT=${PORT:-29503} - - PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ --python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \ -+torchrun --nproc_per_node=$GPUS --master_port=$PORT \ - $(dirname "$0")/test.py $CONFIG $CHECKPOINT --launcher pytorch ${@:4} --out 'seg_result.pkl' -diff --git a/tools/dist_train.sh b/tools/dist_train.sh -index cd9dd42..35e7be6 100755 ---- a/tools/dist_train.sh -+++ b/tools/dist_train.sh -@@ -2,13 +2,14 @@ - # - CONFIG=$1 - GPUS=$2 -+WORK_DIR=$3 - NNODES=${NNODES:-1} - NODE_RANK=${NODE_RANK:-0} - PORT=${PORT:-29500} - MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} - - PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ --python -m torch.distributed.launch \ -+torchrun \ - --nnodes=$NNODES \ - 
--node_rank=$NODE_RANK \ - --master_addr=$MASTER_ADDR \ -@@ -17,4 +18,5 @@ python -m torch.distributed.launch \ - $(dirname "$0")/train.py \ - $CONFIG \ - --seed 0 \ -- --launcher pytorch ${@:3} --deterministic 2>&1 | tee output.log -\ No newline at end of file -+ --work-dir ${WORK_DIR} \ -+ --launcher pytorch ${@:4} --deterministic -\ No newline at end of file -diff --git a/tools/test.py b/tools/test.py -index b7de8a6..ff93666 100755 ---- a/tools/test.py -+++ b/tools/test.py -@@ -1,10 +1,15 @@ -+# Copyright (c) OpenMMLab. All rights reserved. -+# Copyright 2024 Huawei Technologies Co., Ltd - import argparse - import mmcv - import os - import torch -+import torch_npu -+from torch_npu.contrib import transfer_to_npu - import warnings - from mmcv import Config, DictAction - from mmcv.cnn import fuse_conv_bn -+from mmcv.device.npu import NPUDataParallel, NPUDistributedDataParallel - from mmcv.parallel import MMDataParallel, MMDistributedDataParallel - from mmcv.runner import (get_dist_info, init_dist, load_checkpoint, - wrap_fp16_model) -@@ -20,6 +25,9 @@ import time - import os.path as osp - from tools.eval_metrics.lidar_seg import * - -+torch.npu.config.allow_internal_format = False -+ -+ - def parse_args(): - parser = argparse.ArgumentParser( - description='MMDet test (and eval) a model') -@@ -225,11 +233,11 @@ def main(): - model.PALETTE = dataset.PALETTE - - if not distributed: -- # assert False -- model = MMDataParallel(model, device_ids=[0]) -- outputs = single_gpu_test(model, data_loader, args.show, args.show_dir) -+ assert False -+ # model = MMDataParallel(model, device_ids=[0]) -+ # outputs = single_gpu_test(model, data_loader, args.show, args.show_dir) - else: -- model = MMDistributedDataParallel( -+ model = NPUDistributedDataParallel( - model.cuda(), - device_ids=[torch.cuda.current_device()], - broadcast_buffers=False) -diff --git a/tools/train.py b/tools/train.py -index 390d37d..af6e918 100755 ---- a/tools/train.py -+++ b/tools/train.py -@@ -6,6 +6,8 @@ import mmcv - import os - import time - import torch -+import torch_npu -+from torch_npu.contrib import transfer_to_npu - import warnings - from mmcv import Config, DictAction - from mmcv.runner import get_dist_info, init_dist -@@ -22,6 +24,8 @@ from mmseg import __version__ as mmseg_version - - from mmcv.utils import TORCH_VERSION, digit_version - -+torch.npu.config.allow_internal_format = False -+ - - def parse_args(): - parser = argparse.ArgumentParser(description='Train a detector') -@@ -131,10 +135,6 @@ def main(): - # set cudnn_benchmark - if cfg.get('cudnn_benchmark', False): - torch.backends.cudnn.benchmark = True -- # set tf32 -- if cfg.get('close_tf32', False): -- torch.backends.cuda.matmul.allow_tf32 = False -- torch.backends.cudnn.allow_tf32 = False - - # work_dir is determined in this priority: CLI > segment in file > filename - if args.work_dir is not None:
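# Illustrative sketch, not part of the original patch: the tools/train.py and tools/test.py
# hunks above enable the Ascend backend before any CUDA-style API is touched. Minimal
# preamble, assuming torch_npu and an Ascend device are installed; transfer_to_npu
# transparently redirects torch.cuda.* calls to the NPU.
import torch
import torch_npu  # noqa: F401  (registers the NPU backend)
from torch_npu.contrib import transfer_to_npu  # noqa: F401  (remaps cuda-style calls to NPU)

torch.npu.config.allow_internal_format = False  # keep tensors in standard ND format

# Training is then launched with torchrun instead of torch.distributed.launch, e.g.:
#   torchrun --nproc_per_node=<num_npus> --master_port=29500 tools/train.py <config> \
#       --seed 0 --work-dir <work_dir> --launcher pytorch --deterministic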