diff --git a/modules/devices.py b/modules/devices.py index dfffaf24..28c0c54d 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -3,8 +3,7 @@ import contextlib from functools import lru_cache import torch -from modules import errors, shared -from modules import torch_utils +from modules import errors, shared, npu_specific if sys.platform == "darwin": from modules import mac_specific @@ -58,6 +57,9 @@ def get_optimal_device_name(): if has_xpu(): return xpu_specific.get_xpu_device_string() + if npu_specific.has_npu: + return npu_specific.get_npu_device_string() + return "cpu" @@ -85,6 +87,16 @@ def torch_gc(): if has_xpu(): xpu_specific.torch_xpu_gc() + if npu_specific.has_npu: + torch_npu_set_device() + npu_specific.torch_npu_gc() + + +def torch_npu_set_device(): + # Work around due to bug in torch_npu, revert me after fixed, @see https://gitee.com/ascend/pytorch/issues/I8KECW?from=project-issue + if npu_specific.has_npu: + torch.npu.set_device(0) + def enable_tf32(): if torch.cuda.is_available(): @@ -141,7 +153,12 @@ def manual_cast_forward(target_dtype): args = [arg.to(target_dtype) if isinstance(arg, torch.Tensor) else arg for arg in args] kwargs = {k: v.to(target_dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()} - org_dtype = torch_utils.get_param(self).dtype + org_dtype = target_dtype + for param in self.parameters(): + if param.dtype != target_dtype: + org_dtype = param.dtype + break + if org_dtype != target_dtype: self.to(target_dtype) result = self.org_forward(*args, **kwargs) @@ -170,7 +187,7 @@ def manual_cast(target_dtype): continue applied = True org_forward = module_type.forward - if module_type == torch.nn.MultiheadAttention and has_xpu(): + if module_type == torch.nn.MultiheadAttention: module_type.forward = manual_cast_forward(torch.float32) else: module_type.forward = manual_cast_forward(target_dtype) @@ -252,4 +269,3 @@ def first_time_calculation(): x = torch.zeros((1, 1, 3, 3)).to(device, dtype) conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype) conv2d(x) - diff --git a/modules/initialize.py b/modules/initialize.py index 7c1ac99e..f7313ff4 100644 --- a/modules/initialize.py +++ b/modules/initialize.py @@ -142,13 +142,14 @@ def initialize_rest(*, reload_script_modules=False): its optimization may be None because the list of optimizaers has neet been filled by that time, so we apply optimization again. """ + from modules import devices + devices.torch_npu_set_device() shared.sd_model # noqa: B018 if sd_hijack.current_optimizer is None: sd_hijack.apply_optimizations() - from modules import devices devices.first_time_calculation() if not shared.cmd_opts.skip_load_model_at_start: Thread(target=load_model).start() diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 3ff4576a..107c72b0 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -338,6 +338,7 @@ def prepare_environment(): torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://pytorch-extension.intel.com/release-whl/stable/xpu/us/") torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.0a0 intel-extension-for-pytorch==2.0.110+gitba7f6c1 --extra-index-url {torch_index_url}") requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt") + requirements_file_for_npu = os.environ.get('REQS_FILE_FOR_NPU', "requirements_npu.txt") xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.23.post1') clip_package = os.environ.get('CLIP_PACKAGE', "https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip") @@ -421,6 +422,13 @@ def prepare_environment(): run_pip(f"install -r \"{requirements_file}\"", "requirements") startup_timer.record("install requirements") + if not os.path.isfile(requirements_file_for_npu): + requirements_file_for_npu = os.path.join(script_path, requirements_file_for_npu) + + if "torch_npu" in torch_command and not requirements_met(requirements_file_for_npu): + run_pip(f"install -r \"{requirements_file_for_npu}\"", "requirements_for_npu") + startup_timer.record("install requirements_for_npu") + if not args.skip_install: run_extensions_installers(settings_file=args.ui_settings_file) diff --git a/modules/npu_specific.py b/modules/npu_specific.py new file mode 100644 index 00000000..94100691 --- /dev/null +++ b/modules/npu_specific.py @@ -0,0 +1,31 @@ +import importlib +import torch + +from modules import shared + + +def check_for_npu(): + if importlib.util.find_spec("torch_npu") is None: + return False + import torch_npu + + try: + # Will raise a RuntimeError if no NPU is found + _ = torch_npu.npu.device_count() + return torch.npu.is_available() + except RuntimeError: + return False + + +def get_npu_device_string(): + if shared.cmd_opts.device_id is not None: + return f"npu:{shared.cmd_opts.device_id}" + return "npu:0" + + +def torch_npu_gc(): + with torch.npu.device(get_npu_device_string()): + torch.npu.empty_cache() + + +has_npu = check_for_npu() diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py index ef237396..941dff4b 100644 --- a/modules/sd_samplers_cfg_denoiser.py +++ b/modules/sd_samplers_cfg_denoiser.py @@ -94,7 +94,7 @@ class CFGDenoiser(torch.nn.Module): def pad_cond_uncond(self, cond, uncond): empty = shared.sd_model.cond_stage_model_empty_prompt - num_repeats = (cond.shape[1] - cond.shape[1]) // empty.shape[1] + num_repeats = (cond.shape[1] - uncond.shape[1]) // empty.shape[1] if num_repeats < 0: cond = pad_cond(cond, -num_repeats, empty) diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index c6bcab15..6d815c0b 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -150,6 +150,7 @@ class EmbeddingDatabase: return embedding def get_expected_shape(self): + devices.torch_npu_set_device() vec = shared.sd_model.cond_stage_model.encode_embedding_init_text(",", 1) return vec.shape[1] diff --git a/modules/ui_toprow.py b/modules/ui_toprow.py index fbe705be..457fbf52 100644 --- a/modules/ui_toprow.py +++ b/modules/ui_toprow.py @@ -96,9 +96,9 @@ class Toprow: with gr.Row(elem_id=f"{self.id_part}_generate_box", elem_classes=["generate-box"] + (["generate-box-compact"] if self.is_compact else []), render=not self.is_compact) as submit_box: self.submit_box = submit_box - self.interrupt = gr.Button('Interrupt', elem_id=f"{self.id_part}_interrupt", elem_classes="generate-box-interrupt") - self.skip = gr.Button('Skip', elem_id=f"{self.id_part}_skip", elem_classes="generate-box-skip") - self.submit = gr.Button('Generate', elem_id=f"{self.id_part}_generate", variant='primary') + self.interrupt = gr.Button('Interrupt', elem_id=f"{self.id_part}_interrupt", elem_classes="generate-box-interrupt", tooltip="End generation immediately or after completing current batch") + self.skip = gr.Button('Skip', elem_id=f"{self.id_part}_skip", elem_classes="generate-box-skip", tooltip="Stop generation of current batch and continues onto next batch") + self.submit = gr.Button('Generate', elem_id=f"{self.id_part}_generate", variant='primary', tooltip="Right click generate forever menu") self.skip.click( fn=lambda: shared.state.skip(), diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py index afed8b40..b5e5a80c 100644 --- a/modules/upscaler_utils.py +++ b/modules/upscaler_utils.py @@ -6,7 +6,7 @@ import torch import tqdm from PIL import Image -from modules import images, shared, torch_utils +from modules import devices, images, shared, torch_utils logger = logging.getLogger(__name__) @@ -44,7 +44,8 @@ def upscale_pil_patch(model, img: Image.Image) -> Image.Image: with torch.no_grad(): tensor = pil_image_to_torch_bgr(img).unsqueeze(0) # add batch dimension tensor = tensor.to(device=param.device, dtype=param.dtype) - return torch_bgr_to_pil_image(model(tensor)) + with devices.without_autocast(): + return torch_bgr_to_pil_image(model(tensor)) def upscale_with_model( diff --git a/requirements_npu.txt b/requirements_npu.txt new file mode 100644 index 00000000..5e6a4364 --- /dev/null +++ b/requirements_npu.txt @@ -0,0 +1,4 @@ +cloudpickle +decorator +synr==0.5.0 +tornado diff --git a/requirements_versions.txt b/requirements_versions.txt index 2a922f28..5e30b5ea 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -19,7 +19,7 @@ piexif==1.1.3 psutil==5.9.5 pytorch_lightning==1.9.4 resize-right==0.0.2 -safetensors==0.3.1 +safetensors==0.4.2 scikit-image==0.21.0 spandrel==0.1.6 tomesd==0.1.3 diff --git a/webui.sh b/webui.sh index 38258ef6..25b94906 100755 --- a/webui.sh +++ b/webui.sh @@ -158,6 +158,10 @@ then if echo "$gpu_info" | grep -q "AMD" && [[ -z "${TORCH_COMMAND}" ]] then export TORCH_COMMAND="pip install torch==2.0.1+rocm5.4.2 torchvision==0.15.2+rocm5.4.2 --index-url https://download.pytorch.org/whl/rocm5.4.2" + elif echo "$gpu_info" | grep -q "Huawei" && [[ -z "${TORCH_COMMAND}" ]] + then + export TORCH_COMMAND="pip install torch==2.1.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu; pip install torch_npu" + fi fi