diff --git a/torch_npu/contrib/function/fusion_attention.py b/torch_npu/contrib/function/fusion_attention.py
index ddf8453763f8326f780c3365dcffc86cbccd8ebc..1674cd723aa5a343a38ff1c55565f306802a4c8f 100644
--- a/torch_npu/contrib/function/fusion_attention.py
+++ b/torch_npu/contrib/function/fusion_attention.py
@@ -18,23 +18,25 @@ import torch_npu
 def npu_fusion_attention(query, key, value, head_num, input_layout, pse=None, padding_mask=None,
                          atten_mask=None, scale=1., keep_prob=1., pre_tockens=2147483647, next_tockens=2147483647,
                          inner_precise=0, prefix=None, actual_seq_qlen=None, actual_seq_kvlen=None, sparse_mode=0,
-                         gen_mask_parallel=True, sync=False):
+                         gen_mask_parallel=True, sync=False, pse_type=0, q_start_idx=None, kv_start_idx=None):
     return torch_npu.npu_flash_attention(
         query, key, value, head_num, input_layout, pse=pse, padding_mask=padding_mask, atten_mask=atten_mask,
         scale=scale, keep_prob=keep_prob, pre_tockens=pre_tockens, next_tockens=next_tockens,
         inner_precise=inner_precise, prefix=prefix, actual_seq_qlen=actual_seq_qlen, actual_seq_kvlen=actual_seq_kvlen,
-        sparse_mode=sparse_mode, gen_mask_parallel=gen_mask_parallel, sync=sync)
+        sparse_mode=sparse_mode, gen_mask_parallel=gen_mask_parallel, sync=sync, pse_type=pse_type,
+        q_start_idx=q_start_idx, kv_start_idx=kv_start_idx)
 
 
 def npu_fusion_attention_grad(query, key, value, dy, head_num, input_layout, pse=None, padding_mask=None,
                               atten_mask=None, softmax_max=None, softmax_sum=None, softmax_in=None,
                               attention_in=None, scale_value=1., keep_prob=1., pre_tockens=2147483647,
                               next_tockens=2147483647, inner_precise=0, seed=0, offset=0, numels=0, prefix=None, actual_seq_qlen=None,
-                              actual_seq_kvlen=None, sparse_mode=0, gen_mask_parallel=True, sync=False):
+                              actual_seq_kvlen=None, sparse_mode=0, gen_mask_parallel=True, sync=False,
+                              pse_type=0, q_start_idx=None, kv_start_idx=None):
     return torch_npu.npu_flash_attention_grad(
         query, key, value, dy, head_num, input_layout, pse=pse, padding_mask=padding_mask, atten_mask=atten_mask,
         softmax_max=softmax_max, softmax_sum=softmax_sum, softmax_in=softmax_in, attention_in=attention_in,
         scale_value=scale_value, keep_prob=keep_prob, pre_tockens=pre_tockens, next_tockens=next_tockens,
         inner_precise=inner_precise, seed=seed, offset=offset, numels=numels, prefix=prefix,
         actual_seq_qlen=actual_seq_qlen, actual_seq_kvlen=actual_seq_kvlen, sparse_mode=sparse_mode,
-        gen_mask_parallel=gen_mask_parallel, sync=sync)
+        gen_mask_parallel=gen_mask_parallel, sync=sync, pse_type=pse_type, q_start_idx=q_start_idx, kv_start_idx=kv_start_idx)
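
A quick usage sketch of the widened forward signature (not part of the patch). It assumes an Ascend NPU device is available, that "BSND" is an accepted input_layout string, and that the underlying torch_npu.npu_flash_attention op accepts the forwarded pse_type / q_start_idx / kv_start_idx keywords exactly as the wrapper passes them; treat it as an illustration under those assumptions, not the library's documented API.

import torch
import torch_npu
from torch_npu.contrib.function.fusion_attention import npu_fusion_attention

batch, seq, heads, head_dim = 2, 1024, 8, 128
# Query/key/value laid out as (batch, seq, heads, head_dim) to match "BSND".
query = torch.randn(batch, seq, heads, head_dim, dtype=torch.float16).npu()
key = torch.randn(batch, seq, heads, head_dim, dtype=torch.float16).npu()
value = torch.randn(batch, seq, heads, head_dim, dtype=torch.float16).npu()

# The new arguments default to pse_type=0 and q_start_idx/kv_start_idx=None,
# so existing call sites keep their old behavior; they are spelled out here
# only to show where they slot into the call.
out = npu_fusion_attention(
    query, key, value, head_num=heads, input_layout="BSND",
    scale=head_dim ** -0.5, pse_type=0, q_start_idx=None, kv_start_idx=None)

Because the wrapper simply forwards every keyword, `out` has whatever structure torch_npu.npu_flash_attention returns; the defaults chosen for the three new parameters keep the change backward compatible for callers that never pass them.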