diff --git a/test/test_custom_ops/test_prompt_flash_attention.py b/test/test_custom_ops/test_prompt_flash_attention.py
index 9586327d87b8e32a0e00d630e656fee6457e5064..936c0e862b22a35ee88f66794341e81c5148e444 100644
--- a/test/test_custom_ops/test_prompt_flash_attention.py
+++ b/test/test_custom_ops/test_prompt_flash_attention.py
@@ -20,7 +20,7 @@ class TestPromptFlashAttention(TestCase):
     def custom_op_exec(self, query, key, value, head_dim):
         scale = 1 / 0.0078125
         return torch_npu.npu_prompt_flash_attention(
-            query, key, value, num_heads=32, input_layout="BNSD", scale_value=scale, pre_tokens=65535, next_tokens=65535)
+            query, key, value, num_heads=32, input_layout="BNSD", scale_value=scale, pre_tokens=65535, next_tokens=65535, sparse_mode=0)
 
     def test_npu_prompt_flash_attention(self, device="npu"):
         query = torch.randn(1, 32, 2048, 128, dtype=torch.float16).npu()
diff --git a/test/test_network_ops/test_prompt_flash_attention.py b/test/test_network_ops/test_prompt_flash_attention.py
index b47f8c6acdbb3572376e62e2fd7e5b1266b883ff..837284cdee8f8be005c43ba7a75a93fb366e2854 100644
--- a/test/test_network_ops/test_prompt_flash_attention.py
+++ b/test/test_network_ops/test_prompt_flash_attention.py
@@ -19,7 +19,7 @@ class TestPromptFlashAttetion(TestCase):
     def prompt_flash_attention_npu(self, q, k, v, head_dim):
         scale = 1 / 0.0078125
         return torch_npu.npu_prompt_flash_attention(
-            q, k, v, num_heads=32, input_layout="BNSD", scale_value=scale, pre_tokens=65535, next_tokens=65535)
+            q, k, v, num_heads=32, input_layout="BNSD", scale_value=scale, pre_tokens=65535, next_tokens=65535, sparse_mode=0)
 
     def test_op_exec(self):
         q = torch.randn(1, 32, 2048, 128, dtype=torch.float16).npu()
diff --git a/torch_npu/meta/meta_registrations.py b/torch_npu/meta/meta_registrations.py
index 23a0d77ab3bbc4ae8f9d2c71626059e58c7df01a..c62a54c6c80a8493a26fac3bac7a424b223c5b56 100644
--- a/torch_npu/meta/meta_registrations.py
+++ b/torch_npu/meta/meta_registrations.py
@@ -11,7 +11,7 @@ def npu_incre_flash_attention_forward(query, key, value, *, padding_mask=None, a
 
 
 @impl(m, "npu_prompt_flash_attention")
-def npu_prompt_flash_attention_forward(query, key, value, *, padding_mask=None, atten_mask=None, actual_seq_lengths=None, num_heads=1, scale_value=1.0, pre_tokens=2147473647, next_tokens=0, input_layout="BSH", num_key_value_heads=0):
+def npu_prompt_flash_attention_forward(query, key, value, *, padding_mask=None, atten_mask=None, actual_seq_lengths=None, actual_seq_lengths_kv=None, deq_scale1=None, quant_scale1=None, deq_scale2=None, quant_scale2=None, quant_offset2=None, num_heads=1, scale_value=1.0, pre_tokens=2147473647, next_tokens=0, input_layout="BSH", num_key_value_heads=0, sparse_mode=0):
     return torch.empty_like(query, dtype=query.dtype)
 
 
diff --git a/torch_npu/onnx/wrapper_onnx_ops.py b/torch_npu/onnx/wrapper_onnx_ops.py
index 660f38934aafc66963e3cb26bec03c3be9f0208e..3640e3bfefaf61385e105a902e17d4f31a072852 100644
--- a/torch_npu/onnx/wrapper_onnx_ops.py
+++ b/torch_npu/onnx/wrapper_onnx_ops.py
@@ -598,13 +598,19 @@ class NPUPromptFlashAttentionOP(torch.autograd.Function):
     @staticmethod
     def symbolic(g, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
                  padding_mask: Optional[Tensor], atten_mask: Optional[Tensor],
-                 actual_seq_lengths: Optional[Tensor], num_heads: int = 1,
+                 actual_seq_lengths: Optional[Tensor], actual_seq_lengths_kv: Optional[Tensor],
+                 deq_scale1: Optional[Tensor], quant_scale1: Optional[Tensor],
+                 deq_scale2: Optional[Tensor], quant_scale2: Optional[Tensor],
+                 quant_offset2: Optional[Tensor],
+                 num_heads: int = 1,
                  scale_value: float = 1.0, pre_tokens: int = 2147473647, next_tokens: int = 0,
-                 input_layout: str = "BSH", num_key_value_heads: int = 0):
-        return g.op("npu::NPUPromptFlashAttention", self, query, key, value,
-                    padding_mask, atten_mask, actual_seq_lengths,
-                    num_heads, scale_value, pre_tokens, next_tokens,
-                    input_layout, num_key_value_heads)
+                 input_layout: str = "BSH", num_key_value_heads: int = 0,
+                 sparse_mode: int = 0):
+        return g.op("npu::NPUPromptFlashAttention", self, query, key, value,
+                    padding_mask, atten_mask, actual_seq_lengths, actual_seq_lengths_kv,
+                    deq_scale1, quant_scale1, deq_scale2, quant_scale2, quant_offset2,
+                    num_heads, scale_value, pre_tokens, next_tokens,
+                    input_layout, num_key_value_heads, sparse_mode)
 
 
 class NPUIncreFlashAttentionOP(torch.autograd.Function):
@@ -842,10 +848,15 @@ def wrapper_npu_rotary_mul(x, r1, r2):
     return NPURotaryMulOP.apply(x, r1, r2)
 
 
-def wrapper_npu_prompt_flash_attention(self, query, key, value, padding_mask, atten_mask, actual_seq_lengths,
-                                       num_heads, scale_value, pre_tokens, next_tokens, input_layout, num_key_value_heads):
-    return NPUPromptFlashAttentionOP.apply(self, query, key, value, padding_mask, atten_mask, actual_seq_lengths,
-                                           num_heads, scale_value, pre_tokens, next_tokens, input_layout, num_key_value_heads)
+def wrapper_npu_prompt_flash_attention(self, query, key, value, padding_mask, atten_mask, actual_seq_lengths, actual_seq_lengths_kv,
+                                       deq_scale1, quant_scale1, deq_scale2, quant_scale2, quant_offset2,
+                                       num_heads, scale_value, pre_tokens, next_tokens,
+                                       input_layout, num_key_value_heads, sparse_mode):
+    return NPUPromptFlashAttentionOP.apply(self, query, key, value, padding_mask, atten_mask,
+                                           actual_seq_lengths, actual_seq_lengths_kv,
+                                           deq_scale1, quant_scale1, deq_scale2, quant_scale2, quant_offset2,
+                                           num_heads, scale_value, pre_tokens, next_tokens,
+                                           input_layout, num_key_value_heads, sparse_mode)
 
 
 def wrapper_npu_incre_flash_attention(self, query, key, value, padding_mask, atten_mask, actual_seq_lengths,
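
A minimal usage sketch of the extended entry point, not part of the patch: it simply mirrors the BNSD test case in the updated tests above and passes the new sparse_mode keyword explicitly; it assumes torch_npu is installed and an NPU device is available.

# Usage sketch mirroring the updated BNSD test case above.
# Assumes torch_npu is installed and an NPU device is reachable.
import torch
import torch_npu

query = torch.randn(1, 32, 2048, 128, dtype=torch.float16).npu()
key = torch.randn(1, 32, 2048, 128, dtype=torch.float16).npu()
value = torch.randn(1, 32, 2048, 128, dtype=torch.float16).npu()
scale = 1 / 0.0078125

# sparse_mode=0 matches the new default; the other new arguments
# (actual_seq_lengths_kv, deq_scale1, quant_scale1, deq_scale2,
# quant_scale2, quant_offset2) default to None as registered above.
out = torch_npu.npu_prompt_flash_attention(
    query, key, value, num_heads=32, input_layout="BNSD", scale_value=scale,
    pre_tokens=65535, next_tokens=65535, sparse_mode=0)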