diff --git a/model_examples/DenseTNT/DenseTNT_npu.patch b/model_examples/DenseTNT/DenseTNT_npu.patch
index 40badb3f540106a415f9d9c330bc8111f2122072..4f6d30ae1c919f885747e5e90a2870e9e26c8019 100644
--- a/model_examples/DenseTNT/DenseTNT_npu.patch
+++ b/model_examples/DenseTNT/DenseTNT_npu.patch
@@ -99,7 +99,7 @@ index 7f5e435..4f6e66a 100644
              assert hidden_states_query.shape[1] == attention_mask.shape[1] \
                     and hidden_states_key.shape[1] == attention_mask.shape[2]
 diff --git a/src/modeling/vectornet.py b/src/modeling/vectornet.py
-index 502f936..23b88c4 100644
+index 502f936..794dd30 100644
 --- a/src/modeling/vectornet.py
 +++ b/src/modeling/vectornet.py
 @@ -20,8 +20,8 @@ class NewSubGraph(nn.Module):
@@ -135,6 +135,65 @@ index 502f936..23b88c4 100644
          for layer_index, layer in enumerate(self.layers):
              temp = hidden_states
 
+@@ -93,6 +100,7 @@ class VectorNet(nn.Module):
+         :param polyline_spans: vectors of i_th element is matrix[polyline_spans[i]]
+         :return: hidden states of all elements and hidden states of lanes
+         """
++        preprocessed_matrix = [torch.tensor(mat, device=device) for mat in matrix]
+         input_list_list = []
+         # TODO(cyrushx): This is not used? Is it because input_list_list includes map data as well?
+         # Yes, input_list_list includes map data, this will be used in the future release.
+@@ -101,9 +109,11 @@
+         for i in range(batch_size):
+             input_list = []
+             map_input_list = []
++            current_matrix = preprocessed_matrix[i]
+             map_start_polyline_idx = mapping[i]['map_start_polyline_idx']
++
+             for j, polyline_span in enumerate(polyline_spans[i]):
+-                tensor = torch.tensor(matrix[i][polyline_span], device=device)
++                tensor = current_matrix[polyline_span]
+                 input_list.append(tensor)
+                 if j >= map_start_polyline_idx:
+                     map_input_list.append(tensor)
+@@ -111,11 +121,10 @@
+             input_list_list.append(input_list)
+             map_input_list_list.append(map_input_list)
+ 
+-        if True:
+-            element_states_batch = []
+-            for i in range(batch_size):
+-                a, b = self.point_level_sub_graph(input_list_list[i])
+-                element_states_batch.append(a)
++        element_states_batch = []
++        for i in range(batch_size):
++            a, b = self.point_level_sub_graph(input_list_list[i])
++            element_states_batch.append(a)
+ 
+         if 'lane_scoring' in args.other_params:
+             lane_states_batch = []
+@@ -126,16 +135,12 @@
+         # We follow laneGCN to fuse realtime traffic information from agent nodes to lane nodes.
+         if 'laneGCN' in args.other_params:
+             for i in range(batch_size):
+-                map_start_polyline_idx = mapping[i]['map_start_polyline_idx']
+-                agents = element_states_batch[i][:map_start_polyline_idx]
+-                lanes = element_states_batch[i][map_start_polyline_idx:]
++                map_start = mapping[i]['map_start_polyline_idx']
++                agents = element_states_batch[i][:map_start]
++                lanes = element_states_batch[i][map_start:]
+                 # Origin laneGCN contains three fusion layers. Here one fusion layer is enough.
+- if True: +- lanes = lanes + self.laneGCN_A2L(lanes.unsqueeze(0), torch.cat([lanes, agents[0:1]]).unsqueeze(0)).squeeze(0) +- else: +- lanes = lanes + self.laneGCN_A2L(lanes.unsqueeze(0), agents.unsqueeze(0)).squeeze(0) +- lanes = lanes + self.laneGCN_L2L(lanes.unsqueeze(0)).squeeze(0) +- agents = agents + self.laneGCN_L2A(agents.unsqueeze(0), lanes.unsqueeze(0)).squeeze(0) ++ lanes = lanes + self.laneGCN_A2L(lanes.unsqueeze(0), ++ torch.cat([lanes, agents[0:1]]).unsqueeze(0)).squeeze(0) + element_states_batch[i] = torch.cat([agents, lanes]) + + return element_states_batch, lane_states_batch diff --git a/src/run.py b/src/run.py index 989fe1b..f4c92b6 100644 --- a/src/run.py @@ -158,7 +217,7 @@ index 989fe1b..f4c92b6 100644 # initialize the process group dist.init_process_group("nccl", rank=rank, world_size=world_size) diff --git a/src/utils.py b/src/utils.py -index 4290911..9b91d12 100644 +index 4290911..6d8aa66 100644 --- a/src/utils.py +++ b/src/utils.py @@ -276,7 +276,6 @@ def init(args_: Args, logger_): @@ -181,3 +240,18 @@ index 4290911..9b91d12 100644 def add_eval_param(param): if param not in args.eval_params: +@@ -813,7 +817,13 @@ def merge_tensors(tensors: List[torch.Tensor], device, hidden_size=None) -> Tupl + + + def de_merge_tensors(tensor: Tensor, lengths): +- return [tensor[i, :lengths[i]] for i in range(len(lengths))] ++ lengths_tensor = torch.tensor(lengths, dtype=torch.long) ++ B = tensor.size(0) ++ N = tensor.size(1) ++ col_indices = torch.arange(N).expand(B, -1) ++ mask = col_indices < lengths_tensor.unsqueeze(1) ++ selected = tensor[mask] ++ return torch.split(selected, lengths, dim=0) + + + def gather_tensors(tensor: torch.Tensor, indices: List[list]):