diff --git a/torch_npu/onnx/wrapper_onnx_ops.py b/torch_npu/onnx/wrapper_onnx_ops.py
index 91d660d3d5a1d838a2480089fa6e2b3923c39dd8..8a166a4af08ff3d490167912dd16936a40d96dd5 100644
--- a/torch_npu/onnx/wrapper_onnx_ops.py
+++ b/torch_npu/onnx/wrapper_onnx_ops.py
@@ -636,7 +636,24 @@ class NPUFlashAttentionOP(torch.autograd.Function):
                     head_num_i=head_num, input_layout_s=input_layout, scale_f=scale, keep_prob_f=keep_prob,
                     pre_tockens_i=pre_tockens, next_tockens_i=next_tockens,
                     gen_mask_parallel_i=gen_mask_parallel, sync_i=sync)
+# npu_masked_softmax_with_rel_pos_bias(Tensor x, Tensor relative_pos_bias, Tensor? atten_mask=None, float scale_value=1.0, int inner_precision_mode=0)
+class NPUMaskedSoftmaxWithRelPosBiasOP(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, *args, **kwargs):
+        return torch_npu._C._VariableFunctionsClass.npu_masked_softmax_with_rel_pos_bias(*args, **kwargs)
+
+    @staticmethod
+    def symbolic(g, x: Tensor, relative_pos_bias: Tensor, atten_mask: Optional[Tensor] = None,
+                 scale_value: float = 1.0, inner_precision_mode: int = 0):
+        if atten_mask is None:
+            atten_mask = g.op("Constant", value_t=torch.tensor([]).to(torch.float))
+        return g.op("npu::NPUMaskedSoftmaxWithRelPosBias", x, relative_pos_bias, atten_mask,
+                    scale_value_f=scale_value, inner_precision_mode_i=inner_precision_mode)
+
+
+def wrapper_npu_masked_softmax_with_rel_pos_bias(x, relative_pos_bias, atten_mask=None, scale_value=1.0, inner_precision_mode=0):
+    return NPUMaskedSoftmaxWithRelPosBiasOP.apply(x, relative_pos_bias, atten_mask, scale_value, inner_precision_mode)
 
 
 def wrapper_npu_flash_attention(query, key, value, head_num, input_layout,
                                 pse=None, padding_mask=None, atten_mask=None,
@@ -916,3 +933,4 @@ def add_onnx_ops():
     torch_npu.npu_mish = wrapper_npu_mish
     torch_npu.npu_rotary_mul = wrapper_npu_rotary_mul
     torch_npu.npu_flash_attention = wrapper_npu_flash_attention
+    torch_npu.npu_masked_softmax_with_rel_pos_bias = wrapper_npu_masked_softmax_with_rel_pos_bias
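
For reference, a minimal export sketch showing how the new binding is expected to be exercised once add_onnx_ops() has been applied; it is not part of the patch. The module name, tensor shapes, and output file name below are illustrative assumptions, and running it requires an NPU device plus a torch_npu build that ships npu_masked_softmax_with_rel_pos_bias.

    # Illustrative export sketch (not part of the patch).
    import torch
    import torch_npu
    from torch_npu.onnx.wrapper_onnx_ops import add_onnx_ops

    # Rebinds torch_npu.npu_masked_softmax_with_rel_pos_bias to the ONNX-aware wrapper,
    # so NPUMaskedSoftmaxWithRelPosBiasOP.symbolic() is used during tracing/export.
    add_onnx_ops()


    class MaskedSoftmaxModule(torch.nn.Module):
        def forward(self, x, relative_pos_bias, atten_mask):
            # Emits a custom npu::NPUMaskedSoftmaxWithRelPosBias node on export.
            return torch_npu.npu_masked_softmax_with_rel_pos_bias(
                x, relative_pos_bias, atten_mask, scale_value=1.0, inner_precision_mode=0)


    # Placeholder shapes; the operator's real shape constraints are not shown in this patch.
    x = torch.randn(96, 2, 2, 32, 32).npu()
    bias = torch.randn(1, 2, 32, 32).npu()
    mask = torch.randn(96, 1, 1, 32, 32).npu()
    torch.onnx.export(MaskedSoftmaxModule().npu(), (x, bias, mask),
                      "masked_softmax_with_rel_pos_bias.onnx",
                      opset_version=11,
                      input_names=["x", "relative_pos_bias", "atten_mask"])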