diff --git a/test/npu/test_device.py b/test/npu/test_device.py index 57a0690381788bb97fbdc5957de32315b11d3ba9..ed8bde906f82850a35cf0730b0759ebddf00551d 100644 --- a/test/npu/test_device.py +++ b/test/npu/test_device.py @@ -101,22 +101,53 @@ class TestDevice(TestCase): assert (hash(torch.device) == hash(origin_device)) assert (f'{torch.device}' == f'{origin_device}') - def test_multithread_device(self): + def test_multithread_device_with_set_device(self): import threading def _worker(result): try: + torch.npu.set_device("npu:0") cur = torch_npu.npu.current_device() self.assertEqual(cur, 0) except Exception: result[0] = 1 - result = [0] - torch.npu.set_device("npu:0") + result = [0, 0] + + try: + torch.npu.set_device("npu:0") + cur = torch_npu.npu.current_device() + self.assertEqual(cur, 0) + except Exception: + result[1] = 1 + thread = threading.Thread(target=_worker, args=(result,)) + thread.start() + thread.join() + self.assertEqual(result[0], 0) + self.assertEqual(result[1], 0) + + def test_multithread_device_with_no_device(self): + import threading + + def _worker(result): + try: + cur = torch_npu.npu.current_device() + self.assertEqual(cur, 0) + except Exception: + result[0] = 1 + + result = [0, 0] + + try: + cur = torch_npu.npu.current_device() + self.assertEqual(cur, 0) + except Exception: + result[1] = 1 thread = threading.Thread(target=_worker, args=(result,)) thread.start() thread.join() self.assertEqual(result[0], 0) + self.assertEqual(result[1], 0) if __name__ == '__main__': diff --git a/torch_npu/acl.json b/torch_npu/acl.json index 8a77faac4fa153432b380a418d81361586790f59..fc0b7aa696d54dc995063905175143b951e09d91 100644 --- a/torch_npu/acl.json +++ b/torch_npu/acl.json @@ -1 +1,5 @@ -{"dump":{"dump_scene":"lite_exception"}} \ No newline at end of file +{ + "dump": { + "dump_scene": "lite_exception" + } +} \ No newline at end of file