From 322c2044283518394b176c6b064ddd2381087817 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=A8=E5=BD=AC=E6=A6=95?= Date: Thu, 26 Oct 2023 17:16:41 +0800 Subject: [PATCH] update torch_npu/onnx/wrapper_onnx_ops.py. --- .../test_prompt_flash_attention.py | 2 +- .../test_prompt_flash_attention.py | 2 +- torch_npu/meta/meta_registrations.py | 2 +- torch_npu/onnx/wrapper_onnx_ops.py | 31 +++++++++++++------ 4 files changed, 24 insertions(+), 13 deletions(-) diff --git a/test/test_custom_ops/test_prompt_flash_attention.py b/test/test_custom_ops/test_prompt_flash_attention.py index 9586327d87..936c0e862b 100644 --- a/test/test_custom_ops/test_prompt_flash_attention.py +++ b/test/test_custom_ops/test_prompt_flash_attention.py @@ -20,7 +20,7 @@ class TestPromptFlashAttention(TestCase): def custom_op_exec(self, query, key, value, head_dim): scale = 1 / 0.0078125 return torch_npu.npu_prompt_flash_attention( - query, key, value, num_heads=32, input_layout="BNSD", scale_value=scale, pre_tokens=65535, next_tokens=65535) + query, key, value, num_heads=32, input_layout="BNSD", scale_value=scale, pre_tokens=65535, next_tokens=65535, sparse_mode=0) def test_npu_prompt_flash_attention(self, device="npu"): query = torch.randn(1, 32, 2048, 128, dtype=torch.float16).npu() diff --git a/test/test_network_ops/test_prompt_flash_attention.py b/test/test_network_ops/test_prompt_flash_attention.py index b47f8c6acd..837284cdee 100644 --- a/test/test_network_ops/test_prompt_flash_attention.py +++ b/test/test_network_ops/test_prompt_flash_attention.py @@ -19,7 +19,7 @@ class TestPromptFlashAttetion(TestCase): def prompt_flash_attention_npu(self, q, k, v, head_dim): scale = 1 / 0.0078125 return torch_npu.npu_prompt_flash_attention( - q, k, v, num_heads=32, input_layout="BNSD", scale_value=scale, pre_tokens=65535, next_tokens=65535) + q, k, v, num_heads=32, input_layout="BNSD", scale_value=scale, pre_tokens=65535, next_tokens=65535, sparse_mode=0) def test_op_exec(self): q = torch.randn(1, 32, 2048, 128, dtype=torch.float16).npu() diff --git a/torch_npu/meta/meta_registrations.py b/torch_npu/meta/meta_registrations.py index 23a0d77ab3..c62a54c6c8 100644 --- a/torch_npu/meta/meta_registrations.py +++ b/torch_npu/meta/meta_registrations.py @@ -11,7 +11,7 @@ def npu_incre_flash_attention_forward(query, key, value, *, padding_mask=None, a @impl(m, "npu_prompt_flash_attention") -def npu_prompt_flash_attention_forward(query, key, value, *, padding_mask=None, atten_mask=None, actual_seq_lengths=None, num_heads=1, scale_value=1.0, pre_tokens=2147473647, next_tokens=0, input_layout="BSH", num_key_value_heads=0): +def npu_prompt_flash_attention_forward(query, key, value, *, padding_mask=None, atten_mask=None, actual_seq_lengths=None, actual_seq_lengths_kv=None, deq_scale1=None, quant_scale1=None, deq_scale2=None, quant_scale2=None, quant_offset2=None, num_heads=1, scale_value=1.0, pre_tokens=2147473647, next_tokens=0, input_layout="BSH", num_key_value_heads=0, sparse_mode=0): return torch.empty_like(query, dtype=query.dtype) diff --git a/torch_npu/onnx/wrapper_onnx_ops.py b/torch_npu/onnx/wrapper_onnx_ops.py index 660f38934a..3640e3bfef 100644 --- a/torch_npu/onnx/wrapper_onnx_ops.py +++ b/torch_npu/onnx/wrapper_onnx_ops.py @@ -598,13 +598,19 @@ class NPUPromptFlashAttentionOP(torch.autograd.Function): @staticmethod def symbolic(g, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, padding_mask: Optional[Tensor], atten_mask: Optional[Tensor], - actual_seq_lengths: Optional[Tensor], num_heads: int = 1, + actual_seq_lengths: Optional[Tensor], actual_seq_lengths_kv: Optional[Tensor], + deq_scale1: Optional[Tensor], quant_scale1: Optional[Tensor], + deq_scale2: Optional[Tensor], quant_scale2: Optional[Tensor], + quant_offset2: Optional[Tensor], + num_heads: int = 1, scale_value: float = 1.0, pre_tokens: int = 2147473647, next_tokens: int = 0, - input_layout: str = "BSH", num_key_value_heads: int = 0): - return g.op("npu::NPUPromptFlashAttention", self, query, key, value, - padding_mask, atten_mask, actual_seq_lengths, - num_heads, scale_value, pre_tokens, next_tokens, - input_layout, num_key_value_heads) + input_layout: str = "BSH", num_key_value_heads: int = 0, + sparse_mode: int = 0): + return g.op("npu::NPUPromptFlashAttention", self, query, key, value, + padding_mask, atten_mask, actual_seq_lengths, actual_seq_lengths_kv, + deq_scale1, quant_scale1, deq_scale2, quant_scale2, quant_offset2, + num_heads, scale_value, pre_tokens, next_tokens, + input_layout, num_key_value_heads, sparse_mode) class NPUIncreFlashAttentionOP(torch.autograd.Function): @@ -842,10 +848,15 @@ def wrapper_npu_rotary_mul(x, r1, r2): return NPURotaryMulOP.apply(x, r1, r2) -def wrapper_npu_prompt_flash_attention(self, query, key, value, padding_mask, atten_mask, actual_seq_lengths, - num_heads, scale_value, pre_tokens, next_tokens, input_layout, num_key_value_heads): - return NPUPromptFlashAttentionOP.apply(self, query, key, value, padding_mask, atten_mask, actual_seq_lengths, - num_heads, scale_value, pre_tokens, next_tokens, input_layout, num_key_value_heads) +def wrapper_npu_prompt_flash_attention(self, query, key, value, padding_mask, atten_mask, actual_seq_lengths, actual_seq_lengths_kv, + deq_scale1, quant_scale1, deq_scale2, quant_scale2, quant_offset2, + num_heads, scale_value, pre_tokens, next_tokens, + input_layout, num_key_value_heads, sparse_mode): + return NPUPromptFlashAttentionOP.apply(self, query, key, value, padding_mask, atten_mask, + actual_seq_lengths, actual_seq_lengths_kv, + deq_scale1, quant_scale1, deq_scale2, quant_scale2, quant_offset2, + num_heads, scale_value, pre_tokens, next_tokens, + input_layout, num_key_value_heads, sparse_mode) def wrapper_npu_incre_flash_attention(self, query, key, value, padding_mask, atten_mask, actual_seq_lengths, -- Gitee