diff --git a/test/_inductor/test_exceptions.py b/test/_inductor/test_exceptions.py
index 66a9621997a92f0d8ef3dac9b674131297982155..c4c60539163ef0703adb8c9da21e03b5dda489a6 100644
--- a/test/_inductor/test_exceptions.py
+++ b/test/_inductor/test_exceptions.py
@@ -31,7 +31,7 @@ import torch_npu
     'constants': {}, 'mix_mode': 'aiv'},
     inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_unk_fused_add_0', 'mutated_arg_names': [],
     'backend_hash': 'bc71dba4086164e7ac2b0779fa861dbf7467f0265d4a57b8f48cf6dda02b150f', 'split_axis': [0],
-    'tiling_axis': [0, 1], 'axis_names': ['y0', 'x1'], 'low_dims': {1}, 'numof_reduction_axis': 0,
+    'tiling_axis': [0, 1], 'no_loop_axis': [1], 'axis_names': ['y0', 'x1'], 'low_dims': {1}, 'numof_reduction_axis': 0,
     'split_axis_dtype': torch.float16, 'dual_reduction': False,
     'traced_graph_hash': 'TRACED_GRAPH_HASH', 'traced_graph_dir': 'TRACED_GRAPH_DIR'},
     min_elem_per_thread=0
diff --git a/torch_npu/_inductor/codegen/split_tiling.py b/torch_npu/_inductor/codegen/split_tiling.py
index 782cc9f7455cd6d9f4eea63075768fdeb1af0690..addcbf5e572dae0dee93f4ac53639a947de4ca99 100644
--- a/torch_npu/_inductor/codegen/split_tiling.py
+++ b/torch_npu/_inductor/codegen/split_tiling.py
@@ -168,9 +168,44 @@ class SplitTiling:
         for i, x in enumerate(self.kernel.tiling_axis):
             x.tiling_order = i
 
+    # no_loop_axis rule 1: prefer the axes with small numel among the low_dims tiling axes
+    # no_loop_axis rule 2: if the low_dims axes still do not exceed the threshold, pick further axes from the tiling axes
+    # no_loop_axis rule 3: when the estimated footprint of all axes is at most 4K, no loop is needed
+    def select_no_loop_axis(self):
+        sorted_low_dims = [self.kernel.sorted_axis[dim] for dim in self.kernel.low_dims]
+        sorted_low_dims = sorted(sorted_low_dims, key=lambda x: x.length)
+        total_numels = 1
+        def stop_loop(axis, current_numels):
+            if (axis.prefix == 'r' or
+                    not axis.is_tiling_axis or
+                    axis.is_split_axis or
+                    axis.is_no_loop_axis):
+                return False, current_numels
+            current_numels *= axis.length
+            over_flow = current_numels > 4 * 1024
+            if not over_flow:
+                axis.is_no_loop_axis = True
+            return over_flow, current_numels
+
+        if self.kernel.persistent_reduction:
+            for axis in self.kernel.sorted_axis:
+                if axis.prefix == 'r':
+                    total_numels *= axis.length
+
+        for axis in sorted_low_dims:
+            overflow, total_numels = stop_loop(axis, total_numels)
+            if overflow:
+                return
+
+        for axis in reversed(self.kernel.sorted_axis):
+            overflow, total_numels = stop_loop(axis, total_numels)
+            if overflow:
+                return
+
     def select_split_tiling_axis(self):
         self.select_split_axis()
         self.select_tiling_axis()
+        self.select_no_loop_axis()
 
     # the logic below doesn't work when there are two reduction axes but only one needs outer reduction
     def should_outer_reduce_me(self, x):
diff --git a/torch_npu/_inductor/codegen/triton.py b/torch_npu/_inductor/codegen/triton.py
index 0ac327904c77f9284da308f85507fb48f3a59731..78c274e8918d24df7553c512b1f4a22dd5278a74 100644
--- a/torch_npu/_inductor/codegen/triton.py
+++ b/torch_npu/_inductor/codegen/triton.py
@@ -196,11 +196,12 @@ class IterationRangesEntryNPUIndex(IterationRangesEntry):
         # don't use functools.lru_cache(None), so that the indexing_code produced by a previous index
         # can be overwritten
         self.codegen = self._codegen
+        self.is_no_loop_axis = False
 
     # axis mask
     def _codegen_mask(self):
-        if self.is_tiling_axis:
+        if self.is_tiling_axis and not self.is_no_loop_axis:
             BLOCK_NAME = f"{self.name.upper()}BLOCK"
             upper = f"min({BLOCK_NAME}+{self.symbol()}_offset, {self.name}_numel)" if self.is_split_axis else f"{self.name}_numel"
             line = f"{self.name}_mask = {self.name} < {upper}"
@@ -277,7 +278,9 @@ class IterationRangesEntryNPUIndex(IterationRangesEntry):
             else:
                 index = f"(loop_{self.name} * {BLOCK_NAME_SUB}) + base_{self.name}"
         else:
-            if self.is_split_axis:
+            if self.is_no_loop_axis:
+                index = f"base_{self.name}"
+            elif self.is_split_axis:
                 offset = f"{self.symbol()}_offset"
                 index = f"{offset} + (loop_{self.name} * {BLOCK_NAME_SUB}) + base_{self.name}"
             else:
@@ -299,8 +302,10 @@ class IterationRangesEntryNPUIndex(IterationRangesEntry):
         if self.is_split_axis:
             lines.append(f"{self.symbol()}_offset = tl.program_id({self.split_order}) * {BLOCK_NAME}")
-
-        if self.is_tiling_axis:
+
+        if self.is_no_loop_axis:
+            lines.append(f"base_{self.name}= tl.arange(0, {BLOCK_NAME_SUB})")
+        elif self.is_tiling_axis:
             lines.append(f"base_{self.name}= tl.arange(0, {BLOCK_NAME_SUB})")
             block = f"{BLOCK_NAME}" if self.is_split_axis else f"{self.symbol()}_numel"
             lines.append(f"loops_{self.name} = ({block} + {BLOCK_NAME_SUB} - 1) // {BLOCK_NAME_SUB}")
@@ -555,6 +560,7 @@ class NPUIndexTritonKernel(TritonKernel):
                 mutated_args.add(self.args.output_buffers[mutation])
         mutated_args = sorted(mutated_args)
         tiling_axis = [x.sorted_order for x in self.tiling_axis]
+        no_loop_axis = [x.sorted_order for x in self.tiling_axis if x.is_no_loop_axis]
        split_axis = [x.sorted_order for x in self.split_axis]
         axis_names = [x.name for x in self.sorted_axis]
         split_axis_dtype = self.get_axis_dtype(self.split_axis[0]) if self.split_axis else None
@@ -567,6 +573,7 @@ class NPUIndexTritonKernel(TritonKernel):
             "backend_hash": self.patch_triton_hash(),  # torch.utils._triton.triton_hash_with_backend(),
             "split_axis": split_axis,
             "tiling_axis": tiling_axis,
+            "no_loop_axis": no_loop_axis,
             "axis_names": axis_names,
             "low_dims": self.low_dims,
             "numof_reduction_axis": self.numof_reduction_axis(),
@@ -636,6 +643,8 @@ class NPUIndexTritonKernel(TritonKernel):
         for axis in self.tiling_axis:
             if axis.name[0] == 'r' and self.persistent_reduction:
                 continue
+            if axis.is_no_loop_axis:
+                continue
             argdefs.append(f"{axis.name.upper()}BLOCK_SUB: tl.constexpr")
 
     def _get_heuristic(self):
@@ -776,8 +785,17 @@ class NPUIndexTritonKernel(TritonKernel):
                 val = int(simplified_tree_numel)
             else:
                 continue
-            val = next_power_of_2(val)
             code.writeline(f"{node.name.upper()}BLOCK_SUB: tl.constexpr = {val}")
+
+        for axis in self.sorted_axis:
+            if axis.is_no_loop_axis:
+                simplified_tree_numel = V.graph.sizevars.simplify(axis.length)
+                if isinstance(simplified_tree_numel, (sympy.Integer, int)):
+                    val = int(simplified_tree_numel)
+                else:
+                    continue
+                code.writeline(f"{axis.name}_numel = {val}")
+                code.writeline(f"{axis.name.upper()}BLOCK_SUB: tl.constexpr = {val}")
 
     def lowest_axis_variable(self):
         if len(self.tiling_axis) == 0:
@@ -880,7 +898,7 @@ class NPUIndexTritonKernel(TritonKernel):
             need_axis_loop = self.find_axis_in_load_store(range_val)
             if not need_axis_loop:
                 indexing_code = None
-            if (range_val.prefix != 'r' or not self.persistent_reduction) and need_axis_loop:
+            if (range_val.prefix != 'r' or not self.persistent_reduction) and need_axis_loop and not range_val.is_no_loop_axis:
                 self.body.splice(self.prefix)
                 self.body.writeline(f"for loop_{range_val.name} in range(loops_{range_val.name}):")
                 do_indent = True
@@ -894,7 +912,7 @@ class NPUIndexTritonKernel(TritonKernel):
                 if len(self.loads._lines) == 0 and len(self.stores._lines) == 0:
                     do_indent = False
                     indexing_code = None
-                if self.numof_reduction_axis() <= 1:
+                if self.numof_reduction_axis() <= 1 and not range_val.is_no_loop_axis:
                     do_indent = True
                     self.body.writeline(f"for loop_{range_val.name} in range(loops_{range_val.name}):")
             loop_body(index, indexing_code, is_last_axis, do_indent=do_indent)
@@ -1117,9 +1135,17 @@ class NPUIndexTritonKernel(TritonKernel):
         return self.reduce_analysis.reduced_dim
 
     def filter_masks(self, mask_vars):
+        mask_vars_copy = mask_vars.copy()
+
+        def remove_mask_from_node(node_name):
+            for mask_var in mask_vars_copy:
+                if mask_var.startswith(node_name):
+                    mask_vars.discard(mask_var)
+
         for node in self.sorted_axis:
-            if not (node.is_tiling_axis):
-                mask_vars.discard(f"{node.name}_mask")
+            if ((not node.is_tiling_axis) or
+                    (self.persistent_reduction and node.is_reduction) or
+                    node.is_no_loop_axis):
+                remove_mask_from_node(node.name)
 
     def numof_reduction_axis(self):
         root = self.range_trees[-1]
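
For orientation, the heuristic in select_no_loop_axis above reduces to a greedy scan under a 4K-element budget: an axis is marked loop-free while the product of the chosen axes' lengths stays at or below 4 * 1024, with low-dim tiling axes (smallest numel first) considered before the remaining tiling axes. Below is a minimal standalone sketch of that scan; the Axis dataclass and the select_no_loop_axes helper are hypothetical stand-ins for the kernel's real axis objects, and the persistent-reduction seeding of the budget is omitted for brevity.

from dataclasses import dataclass
from typing import List

@dataclass
class Axis:
    name: str
    length: int
    prefix: str = 'x'             # 'r' would mark a reduction axis
    is_tiling_axis: bool = True
    is_split_axis: bool = False
    is_no_loop_axis: bool = False

THRESHOLD = 4 * 1024  # rule 3: a combined numel of at most 4K needs no loop

def select_no_loop_axes(sorted_axis: List[Axis], low_dims: List[int]) -> None:
    total = 1

    def try_mark(axis: Axis, total: int):
        # reduction, non-tiling, split, and already-marked axes are skipped
        if (axis.prefix == 'r' or not axis.is_tiling_axis
                or axis.is_split_axis or axis.is_no_loop_axis):
            return False, total
        total *= axis.length
        if total <= THRESHOLD:
            axis.is_no_loop_axis = True
            return False, total
        return True, total  # budget exceeded: caller stops searching

    # rule 1: low-dim tiling axes first, smallest numel first
    for axis in sorted((sorted_axis[d] for d in low_dims), key=lambda a: a.length):
        overflow, total = try_mark(axis, total)
        if overflow:
            return
    # rule 2: then the remaining tiling axes, innermost first
    for axis in reversed(sorted_axis):
        overflow, total = try_mark(axis, total)
        if overflow:
            return

axes = [Axis('y0', 64), Axis('x1', 32)]
select_no_loop_axes(axes, low_dims=[1])
print([a.name for a in axes if a.is_no_loop_axis])  # ['y0', 'x1']: 32 * 64 <= 4096

For an axis selected this way, the codegen changes in triton.py then emit base_{name} = tl.arange(0, {NAME}BLOCK_SUB) with {NAME}BLOCK_SUB pinned to the axis numel, and skip the for loop_{name} in range(loops_{name}) wrapper as well as the {name}_mask guard, since a single block covers the whole axis.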