From 21ab90b2b6cd5c4609b64a87bf441f7762657d46 Mon Sep 17 00:00:00 2001 From: gsoleil Date: Mon, 8 Sep 2025 16:23:52 +0800 Subject: [PATCH] =?UTF-8?q?QCNet=E6=A8=A1=E5=9E=8B=E5=8A=A8=E6=80=81?= =?UTF-8?q?=E6=95=B0=E6=8D=AE=E5=8A=A0=E8=BD=BD=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- model_examples/QCNet/patch/qcnet.patch | 83 +++++++++++++++++++++++++- 1 file changed, 80 insertions(+), 3 deletions(-) diff --git a/model_examples/QCNet/patch/qcnet.patch b/model_examples/QCNet/patch/qcnet.patch index c61844bb..425e8aab 100644 --- a/model_examples/QCNet/patch/qcnet.patch +++ b/model_examples/QCNet/patch/qcnet.patch @@ -1,5 +1,82 @@ +diff --git a/datamodules/argoverse_v2_dataloader.py b/datamodules/argoverse_v2_dataloader.py +new file mode 100644 +index 0000000..4e1e50f +--- /dev/null ++++ b/datamodules/argoverse_v2_dataloader.py +@@ -0,0 +1,70 @@ ++# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. ++# ------------------------------------------------------------------------ ++# Copyright (c) 2023, Zikang Zhou. All rights reserved. ++# ++# Licensed under the Apache License, Version 2.0 (the "License"); ++# ++# Unless required by applicable law or agreed to in writing, software ++# distributed under the License is distributed on an "AS IS" BASIS, ++# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++# See the License for the specific language governing permissions and ++# limitations under the License. ++# ------------------------------------------------------------------------ ++# Copyright (c) 2023 PyG Team. All rights reserved. ++# ------------------------------------------------------------------------ ++# Copyright (c) 2018 Matthias Fey. All rights reserved. ++# ------------------------------------------------------------------------ ++# Licensed under The MIT License [see LICENSE for details] ++ ++from collections.abc import Mapping ++from typing import TypeVar, Optional, Iterator, List, Optional, Sequence, Union ++from functools import partial ++import math ++import numpy as np ++import torch ++import torch.utils.data ++from torch.utils.data import DataLoader ++ ++from torch_geometric.data import Batch, Dataset ++from torch_geometric.data.data import BaseData ++from torch_geometric.data.datapipes import DatasetAdapter ++ ++from mx_driving.dataset.agent_dataset import AgentDynamicBatchSampler, DynamicBatchSampler, Collater ++ ++ ++class QCNetDynamicBatchSampler(AgentDynamicBatchSampler): ++ def bucket_arange(self): ++ g = torch.Generator() ++ if self.epoch >= 44: ++ self.seed=1 ++ g.manual_seed(self.epoch + self.seed) ++ # Shuffling buckets order. ++ bucket_index = torch.randperm(len(self.dataset.buckets), generator=g).tolist() ++ indices = [] ++ for bct_idx in bucket_index: ++ bucket = self.dataset.buckets[bct_idx] ++ # Shuffling samples in a bucket. ++ indices.extend([bucket[i] for i in torch.randperm(len(bucket), generator=g).tolist()]) ++ ++ return indices ++ ++ ++class QCNetDynamicBatchDataLoader(DataLoader): ++ def __init__(self, ++ dataset: Union[Dataset, Sequence[BaseData], DatasetAdapter], ++ batch_size: int, ++ train_batch_size: int, ++ shuffle: bool = True, ++ follow_batch: Optional[List[str]] = None, ++ exclude_keys: Optional[List[str]] = None, ++ **kwargs) -> None: ++ kwargs.pop('collate_fn', None) ++ kwargs.pop('batch_sampler', None) ++ ++ self.follow_batch = follow_batch ++ self.exclude_keys = exclude_keys ++ sampler = QCNetDynamicBatchSampler(dataset, shuffle=True) ++ ++ super().__init__( ++ dataset, collate_fn=Collater(follow_batch, exclude_keys), ++ batch_sampler=DynamicBatchSampler(dataset, sampler, train_batch_size), **kwargs) +\ No newline at end of file diff --git a/datamodules/argoverse_v2_datamodule.py b/datamodules/argoverse_v2_datamodule.py -index 1b55133..e3c92bb 100644 +index 1b55133..1efee1c 100644 --- a/datamodules/argoverse_v2_datamodule.py +++ b/datamodules/argoverse_v2_datamodule.py @@ -13,8 +13,9 @@ @@ -9,7 +86,7 @@ index 1b55133..e3c92bb 100644 -import pytorch_lightning as pl +import lightning.pytorch as pl from torch_geometric.loader import DataLoader -+from mx_driving.dataset import AgentDynamicBatchDataLoader ++from .argoverse_v2_dataloader import QCNetDynamicBatchDataLoader from datasets import ArgoverseV2Dataset from transforms import TargetBuilder @@ -51,7 +128,7 @@ index 1b55133..e3c92bb 100644 - num_workers=self.num_workers, pin_memory=self.pin_memory, - persistent_workers=self.persistent_workers) + if self.dynamic_sort: -+ return AgentDynamicBatchDataLoader(self.train_dataset, batch_size=self.train_batch_size, train_batch_size=self.train_batch_size, shuffle=self.shuffle, ++ return QCNetDynamicBatchDataLoader(self.train_dataset, batch_size=self.train_batch_size, train_batch_size=self.train_batch_size, shuffle=self.shuffle, + num_workers=self.num_workers, pin_memory=self.pin_memory, + persistent_workers=self.persistent_workers) + else: -- Gitee