diff --git a/modules/devices.py b/modules/devices.py index 62dc9f42..7c09d1f4 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -38,15 +38,9 @@ def get_device_for(task): def torch_gc(): model_management.soft_empty_cache() - if npu_specific.has_npu: - torch_npu_set_device() - npu_specific.torch_npu_gc() - def torch_npu_set_device(): - # Work around due to bug in torch_npu, revert me after fixed, @see https://gitee.com/ascend/pytorch/issues/I8KECW?from=project-issue - if npu_specific.has_npu: - torch.npu.set_device(0) + return def enable_tf32():