diff --git a/torch_npu/testing/common_distributed.py b/torch_npu/testing/common_distributed.py index c19d3ff0336ab2ea54a07aae2d689e0136c28555..9a55c2f75fcac6337ddec5e5028816fba9ed354c 100644 --- a/torch_npu/testing/common_distributed.py +++ b/torch_npu/testing/common_distributed.py @@ -27,3 +27,14 @@ def skipIfUnsupportMultiNPU(npu_number_needed): return func(self) return wrapper return skip_dec + + +def skipOpOnlySupport910B(op_name): + def skip_dec(func): + def wrapper(self): + device_name = torch_npu.npu.get_device_name(0)[:10] + if device_name != 'Ascend910B': + raise unittest.SkipTest(f"The op {op_name} is only suppport on 910B, skip this ut for this device type!") + return func(self) + return wrapper + return skip_dec