diff --git a/test/torch_npu_schema.json b/test/torch_npu_schema.json
index 7d85362d113ae8f2a5c1516a7b5fb2b9c88e17f8..64ff07640c7a8ec605a16dfec84147498cca4af6 100644
--- a/test/torch_npu_schema.json
+++ b/test/torch_npu_schema.json
@@ -2586,7 +2586,7 @@
         "signature": "(query_layer, key_layer, value_layer, attention_mask, scale, keep_prob, query_transpose=False, key_transpose=False, bmm_score_transpose_a=False, bmm_score_transpose_b=False, value_transpose=False, dx_transpose=False)"
     },
     "torch_npu.npu_fused_infer_attention_score": {
-        "signature": "(self, query, key, value, pse_shift, atten_mask, actual_seq_lengths, actual_seq_lengths_kv, dequant_scale1, quant_scale1, dequant_scale2, quant_scale2, quant_offset2, antiquant_scale, antiquant_offset, block_table, query_padding_size, kv_padding_size, num_heads, scale, pre_tokens, next_tokens, input_layout, key_antiquant_scale, key_antiquant_offset, value_antiquant_scale, value_antiquant_offset, key_shared_prefix, value_shared_prefix, actual_shared_prefix_len, query_rope, key_rope, num_key_value_heads, sparse_mode, inner_precise, block_size, antiquant_mode, softmax_lse_flag, key_antiquant_mode, value_antiquant_mode)"
+        "signature": "(self, query, key, value, pse_shift, atten_mask, actual_seq_lengths, actual_seq_lengths_kv, dequant_scale1, quant_scale1, dequant_scale2, quant_scale2, quant_offset2, antiquant_scale, antiquant_offset, block_table, query_padding_size, kv_padding_size, num_heads, scale, pre_tokens, next_tokens, input_layout, key_antiquant_scale, key_antiquant_offset, value_antiquant_scale, value_antiquant_offset, key_shared_prefix, value_shared_prefix, actual_shared_prefix_len, query_rope, key_rope, key_rope_antiquant_scale, num_key_value_heads, sparse_mode, inner_precise, block_size, antiquant_mode, softmax_lse_flag, key_antiquant_mode, value_antiquant_mode)"
     },
     "torch_npu.npu_fusion_attention": {
         "signature": "(*args, **kwargs)"
diff --git a/torch_npu/onnx/wrapper_onnx_ops.py b/torch_npu/onnx/wrapper_onnx_ops.py
index a261d1785911b154eae6ee0a18e82b236de76921..e132c839ff6b4cbb43abd04d3b467aa76e466ef2 100644
--- a/torch_npu/onnx/wrapper_onnx_ops.py
+++ b/torch_npu/onnx/wrapper_onnx_ops.py
@@ -713,6 +713,7 @@ class _NPUFusedInferAttentionScoreOP(torch.autograd.Function):
                 actual_shared_prefix_len: Optional[Tensor],
                 query_rope: Optional[Tensor],
                 key_rope: Optional[Tensor],
+                key_rope_antiquant_scale: Optional[Tensor],
                 num_heads: int = 1, scale: float = 1.0, pre_tokens: int = 2147483647,
                 next_tokens: int = 2147483647, input_layout: str = "BSH", num_key_value_heads: int = 0,
@@ -724,7 +725,7 @@ class _NPUFusedInferAttentionScoreOP(torch.autograd.Function):
             dequant_scale1, quant_scale1, dequant_scale2, quant_scale2, quant_offset2,
             antiquant_scale, antiquant_offset, block_table, query_padding_size, kv_padding_size,
             key_antiquant_scale, key_antiquant_offset, value_antiquant_scale, value_antiquant_offset,
-            key_shared_prefix, value_shared_prefix, actual_shared_prefix_len, query_rope, key_rope,
+            key_shared_prefix, value_shared_prefix, actual_shared_prefix_len, query_rope, key_rope, key_rope_antiquant_scale,
             num_heads, scale, pre_tokens, next_tokens, input_layout,
             num_key_value_heads, sparse_mode, inner_precise,
             block_size, antiquant_mode, softmax_lse_flag, key_antiquant_mode, value_antiquant_mode)
@@ -1299,7 +1300,7 @@ def _wrapper_npu_fused_infer_attention_score(self, query, key, value, pse_shift,
                                             antiquant_offset, block_table, query_padding_size, kv_padding_size,
                                             num_heads, scale, pre_tokens, next_tokens, input_layout, key_antiquant_scale,
                                             key_antiquant_offset, value_antiquant_scale, value_antiquant_offset,
-                                            key_shared_prefix, value_shared_prefix, actual_shared_prefix_len, query_rope, key_rope,
+                                            key_shared_prefix, value_shared_prefix, actual_shared_prefix_len, query_rope, key_rope, key_rope_antiquant_scale,
                                             num_key_value_heads, sparse_mode, inner_precise, block_size,
                                             antiquant_mode, softmax_lse_flag, key_antiquant_mode, value_antiquant_mode):
@@ -1309,7 +1310,7 @@ def _wrapper_npu_fused_infer_attention_score(self, query, key, value, pse_shift,
                                                         quant_offset2, antiquant_scale, antiquant_offset, block_table,
                                                         query_padding_size, kv_padding_size, key_antiquant_scale,
                                                         key_antiquant_offset, value_antiquant_scale, value_antiquant_offset,
-                                                        key_shared_prefix, value_shared_prefix, actual_shared_prefix_len, query_rope, key_rope,
+                                                        key_shared_prefix, value_shared_prefix, actual_shared_prefix_len, query_rope, key_rope, key_rope_antiquant_scale,
                                                         num_heads, scale, pre_tokens, next_tokens, input_layout,
                                                         num_key_value_heads, sparse_mode, inner_precise,
                                                         block_size, antiquant_mode, softmax_lse_flag,
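Note on the change: a new optional tensor argument, key_rope_antiquant_scale, is threaded through the ONNX wrapper for npu_fused_infer_attention_score, placed immediately after key_rope everywhere the argument list appears (schema signature, forward signature, and both wrapper call sites). A minimal usage sketch follows, assuming an available NPU device; the toy shapes, dtypes, and the two-tensor return value are assumptions drawn from the op's public signature rather than from this diff, and key_rope_antiquant_scale is left as None here, which should preserve the pre-change behaviour.

    import math

    import torch
    import torch_npu

    # Assumed toy sizes: batch=1, query length=1, kv length=512, hidden=128, 1 head.
    query = torch.randn(1, 1, 128, dtype=torch.float16).npu()
    key = torch.randn(1, 512, 128, dtype=torch.float16).npu()
    value = torch.randn(1, 512, 128, dtype=torch.float16).npu()

    # Optional tensors default to None. The new key_rope_antiquant_scale sits
    # right after key_rope in the schema; a real scale tensor would be passed
    # when the RoPE part of the key cache is stored quantized (an assumption
    # about its intended use, not stated in the diff).
    attn_out, softmax_lse = torch_npu.npu_fused_infer_attention_score(
        query, key, value,
        num_heads=1,
        input_layout="BSH",
        scale=1.0 / math.sqrt(128),
        query_rope=None,
        key_rope=None,
        key_rope_antiquant_scale=None,
    )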