From 7df01e6f4565adae61e904cab9886cf721a95913 Mon Sep 17 00:00:00 2001 From: gsoleil Date: Tue, 2 Sep 2025 12:49:36 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9QCNet=E4=B8=ADpytorch-lightni?= =?UTF-8?q?ng=E7=9A=84=E5=AE=89=E8=A3=85?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- model_examples/QCNet/README.md | 6 +- .../QCNet/patch/torch_geometric.patch | 80 +++++++++++++++---- model_examples/QCNet/script/train.sh | 2 +- .../QCNet/script/train_performance.sh | 4 +- 4 files changed, 72 insertions(+), 20 deletions(-) diff --git a/model_examples/QCNet/README.md b/model_examples/QCNet/README.md index 7673306b..33995dc2 100644 --- a/model_examples/QCNet/README.md +++ b/model_examples/QCNet/README.md @@ -97,9 +97,9 @@ code_path=model_examples/QCNet 3. 安装pytorch_lightening ``` - git clone https://github.com/Lightning-AI/pytorch-lightning.git -b builds/2.3.1 + git clone --branch 2.3.3 https://github.com/Lightning-AI/pytorch-lightning.git cd pytorch-lightning/ - git checkout 8e39ef55142e3cf1878efee85cfbeb0ed0ce29b5 + git checkout cf348673eda662cc2e9aa71a72a19b8774f85718 git apply ../patch/lightning.patch pip install -e ./ --no-deps cd .. @@ -190,7 +190,7 @@ cd model_examples/QCNet | 芯片 | 卡数 | global batch size | epoch | minFDE | minADE | 性能-单步迭代耗时(s) | | :-----------: | :--: | :---------------: | :---: | :--------------------: | :--------------------: | :--------------: | | 竞品A | 8p | 32 | 64 | 1.259 | 0.721 | 0.34 | -| Atlas 800T A2 | 8p | 32 | 64 | 1.250 | 0.723 | 0.43 | +| Atlas 800T A2 | 8p | 32 | 64 | 1.250 | 0.723 | 0.425 | # 变更说明 diff --git a/model_examples/QCNet/patch/torch_geometric.patch b/model_examples/QCNet/patch/torch_geometric.patch index 7414e1af..719d6804 100644 --- a/model_examples/QCNet/patch/torch_geometric.patch +++ b/model_examples/QCNet/patch/torch_geometric.patch @@ -1,48 +1,100 @@ diff --git a/torch_geometric/nn/conv/message_passing.py b/torch_geometric/nn/conv/message_passing.py -index 2db5c9893..23fa4ba07 100644 +index 2db5c9893..cb99ec80f 100644 --- a/torch_geometric/nn/conv/message_passing.py +++ b/torch_geometric/nn/conv/message_passing.py -@@ -42,6 +42,7 @@ from torch_geometric.utils import ( - to_edge_index, +@@ -20,7 +20,6 @@ from uuid import uuid1 + import torch + from torch import Tensor + from torch.utils.hooks import RemovableHandle +- + from torch_geometric.nn.aggr import Aggregation, MultiAggregation + from torch_geometric.nn.conv.utils.inspector import ( + Inspector, +@@ -43,6 +42,18 @@ from torch_geometric.utils import ( ) from torch_geometric.utils.sparse import ptr2index -+from mx_driving import npu_index_select ++import torch_npu ++import ctypes ++try: ++ from mx_driving import npu_index_select ++ npu_index_select_available = True ++except ImportError: ++ npu_index_select_available = False ++ ++lib_path = os.path.join(os.environ.get('ASCEND_HOME_PATH', '/usr/local/Ascend'), 'lib64', 'libopapi.so') ++lib = ctypes.CDLL(lib_path) ++aclnnIndexAddV2_available = hasattr(lib, "aclnnIndexAddV2") ++index_select_op = None FUSE_AGGRS = {'add', 'sum', 'mean', 'min', 'max'} -@@ -264,12 +265,16 @@ class MessagePassing(torch.nn.Module): + +@@ -190,6 +201,16 @@ class MessagePassing(torch.nn.Module): + self._edge_update_forward_pre_hooks = OrderedDict() + self._edge_update_forward_hooks = OrderedDict() + ++ global index_select_op ++ ++ # 检查是否已初始化 ++ if index_select_op is not None: ++ return ++ if npu_index_select_available and aclnnIndexAddV2_available: ++ index_select_op = npu_index_select ++ else: ++ index_select_op = torch_npu.index_select ++ + def reset_parameters(self): + r"""Resets all learnable parameters of the module.""" + if self.aggr_module is not None: +@@ -264,12 +285,24 @@ class MessagePassing(torch.nn.Module): else: raise ValueError(f"Unsupported sparse tensor layout " f"(got '{edge_index.layout}')") - return src.index_select(self.node_dim, index) + tmp = src.reshape(len(src), -1) -+ x = npu_index_select(tmp, self.node_dim, index) -+ return x.reshape(len(x), 8, -1) ++ x = index_select_op(tmp, self.node_dim, index) ++ if len(x) == 0: ++ x_shape = x.shape[1] ++ return x.reshape(len(x), 8, x_shape//8) ++ else: ++ return x.reshape(len(x), 8, -1) elif isinstance(edge_index, Tensor): try: index = edge_index[dim] - return src.index_select(self.node_dim, index) + tmp = src.reshape(len(src), -1) -+ x = npu_index_select(tmp, self.node_dim, index) -+ return x.reshape(len(x), 8, -1) ++ x = index_select_op(tmp, self.node_dim, index) ++ if len(x) == 0: ++ x_shape = x.shape[1] ++ return x.reshape(len(x), 8, x_shape//8) ++ else: ++ return x.reshape(len(x), 8, -1) except (IndexError, RuntimeError) as e: if index.min() < 0 or index.max() >= src.size(self.node_dim): raise IndexError( -@@ -304,10 +309,14 @@ class MessagePassing(torch.nn.Module): +@@ -304,10 +337,22 @@ class MessagePassing(torch.nn.Module): elif isinstance(edge_index, SparseTensor): if dim == 0: col = edge_index.storage.col() - return src.index_select(self.node_dim, col) + tmp = src.reshape(len(src), -1) -+ x = npu_index_select(tmp, self.node_dim, col) -+ return x.reshape(len(x), 8, -1) ++ x = index_select_op(tmp, self.node_dim, col) ++ if len(x) == 0: ++ x_shape = x.shape[1] ++ return x.reshape(len(x), 8, x_shape//8) ++ else: ++ return x.reshape(len(x), 8, -1) elif dim == 1: row = edge_index.storage.row() - return src.index_select(self.node_dim, row) + tmp = src.reshape(len(src), -1) -+ x = npu_index_select(tmp, self.node_dim, row) -+ return x.reshape(len(x), 8, -1) ++ x = index_select_op(tmp, self.node_dim, row) ++ if len(x) == 0: ++ x_shape = x.shape[1] ++ return x.reshape(len(x), 8, x_shape//8) ++ else: ++ return x.reshape(len(x), 8, -1) raise ValueError( ('`MessagePassing.propagate` only supports integer tensors of ' diff --git a/model_examples/QCNet/script/train.sh b/model_examples/QCNet/script/train.sh index 76cf550a..bac3c5ab 100644 --- a/model_examples/QCNet/script/train.sh +++ b/model_examples/QCNet/script/train.sh @@ -13,7 +13,7 @@ export ACLNN_CACHE_LIMIT=500000 cur_path=$(pwd) ASCEND_DEVICE_ID=0 -log_file="$cur_path/test/output/${ASCEND_DEVICE_ID}/train_${ASCEND_DEVICE_ID}.log" +log_file="$cur_path/train_${ASCEND_DEVICE_ID}.log" # /path/to/datasets 请更改为存放数据的路径 python QCNet/train_qcnet.py --root /path/to/datasets --train_batch_size 4 \ diff --git a/model_examples/QCNet/script/train_performance.sh b/model_examples/QCNet/script/train_performance.sh index 6a7d557a..545af298 100644 --- a/model_examples/QCNet/script/train_performance.sh +++ b/model_examples/QCNet/script/train_performance.sh @@ -13,7 +13,7 @@ export ACLNN_CACHE_LIMIT=500000 cur_path=$(pwd) ASCEND_DEVICE_ID=0 -log_file="$cur_path/test/output/${ASCEND_DEVICE_ID}/train_${ASCEND_DEVICE_ID}.log" +log_file="$cur_path/train_${ASCEND_DEVICE_ID}.log" # /path/to/datasets 请更改为存放数据的路径 python QCNet/train_qcnet.py --root /path/to/datasets --train_batch_size 4 \ @@ -22,5 +22,5 @@ python QCNet/train_qcnet.py --root /path/to/datasets --train_batch_size 4 \ --pl2pl_radius 150 --time_span 10 --pl2a_radius 50 --a2a_radius 50 \ --num_t2m_steps 30 --pl2m_radius 150 --a2m_radius 150 --T_max 1 --max_epochs 1 >$log_file 2>&1 -final_epoch_time=$(tac "$log_file" | grep -m1 "Average Training Time" | grep -oP 'Average Training Time \K\d+\.\d+') +final_epoch_time=$(tac "$log_file" | grep -m1 "Average Training Time" | grep -oP 'Average Training Time.*: \K\d+\.\d+') echo "FPS: ${final_epoch_time}s" -- Gitee