mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-03-09 06:59:48 +00:00
move to new backend - part 1
This commit is contained in:
@@ -3,9 +3,7 @@ import json
|
||||
import os
|
||||
from modules.paths_internal import normalized_filepath, models_path, script_path, data_path, extensions_dir, extensions_builtin_dir, sd_default_config, sd_model_file # noqa: F401
|
||||
from pathlib import Path
|
||||
from ldm_patched.modules import args_parser
|
||||
|
||||
parser = args_parser.parser
|
||||
from backend.args import parser
|
||||
|
||||
parser.add_argument("-f", action='store_true', help=argparse.SUPPRESS) # allows running as root; implemented outside of webui
|
||||
parser.add_argument("--update-all-extensions", action='store_true', help="launch.py argument: download updates for all extensions when starting the program")
|
||||
|
||||
@@ -5,8 +5,8 @@ import torch
|
||||
import numpy as np
|
||||
|
||||
from modules import modelloader, paths, deepbooru_model, images, shared
|
||||
from ldm_patched.modules import model_management
|
||||
from ldm_patched.modules.model_patcher import ModelPatcher
|
||||
from backend import memory_management
|
||||
from backend.patcher.base import ModelPatcher
|
||||
|
||||
|
||||
re_special = re.compile(r'([\\()])')
|
||||
@@ -15,11 +15,11 @@ re_special = re.compile(r'([\\()])')
|
||||
class DeepDanbooru:
|
||||
def __init__(self):
|
||||
self.model = None
|
||||
self.load_device = model_management.text_encoder_device()
|
||||
self.offload_device = model_management.text_encoder_offload_device()
|
||||
self.load_device = memory_management.text_encoder_device()
|
||||
self.offload_device = memory_management.text_encoder_offload_device()
|
||||
self.dtype = torch.float32
|
||||
|
||||
if model_management.should_use_fp16(device=self.load_device):
|
||||
if memory_management.should_use_fp16(device=self.load_device):
|
||||
self.dtype = torch.float16
|
||||
|
||||
self.patcher = None
|
||||
@@ -45,7 +45,7 @@ class DeepDanbooru:
|
||||
|
||||
def start(self):
|
||||
self.load()
|
||||
model_management.load_models_gpu([self.patcher])
|
||||
memory_management.load_models_gpu([self.patcher])
|
||||
|
||||
def stop(self):
|
||||
pass
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
import contextlib
|
||||
import torch
|
||||
import ldm_patched.modules.model_management as model_management
|
||||
from backend import memory_management
|
||||
|
||||
|
||||
def has_xpu() -> bool:
|
||||
return model_management.xpu_available
|
||||
return memory_management.xpu_available
|
||||
|
||||
|
||||
def has_mps() -> bool:
|
||||
return model_management.mps_mode()
|
||||
return memory_management.mps_mode()
|
||||
|
||||
|
||||
def cuda_no_autocast(device_id=None) -> bool:
|
||||
@@ -16,27 +16,27 @@ def cuda_no_autocast(device_id=None) -> bool:
|
||||
|
||||
|
||||
def get_cuda_device_id():
|
||||
return model_management.get_torch_device().index
|
||||
return memory_management.get_torch_device().index
|
||||
|
||||
|
||||
def get_cuda_device_string():
|
||||
return str(model_management.get_torch_device())
|
||||
return str(memory_management.get_torch_device())
|
||||
|
||||
|
||||
def get_optimal_device_name():
|
||||
return model_management.get_torch_device().type
|
||||
return memory_management.get_torch_device().type
|
||||
|
||||
|
||||
def get_optimal_device():
|
||||
return model_management.get_torch_device()
|
||||
return memory_management.get_torch_device()
|
||||
|
||||
|
||||
def get_device_for(task):
|
||||
return model_management.get_torch_device()
|
||||
return memory_management.get_torch_device()
|
||||
|
||||
|
||||
def torch_gc():
|
||||
model_management.soft_empty_cache()
|
||||
memory_management.soft_empty_cache()
|
||||
|
||||
|
||||
def torch_npu_set_device():
|
||||
@@ -49,15 +49,15 @@ def enable_tf32():
|
||||
|
||||
cpu: torch.device = torch.device("cpu")
|
||||
fp8: bool = False
|
||||
device: torch.device = model_management.get_torch_device()
|
||||
device_interrogate: torch.device = model_management.text_encoder_device() # for backward compatibility, not used now
|
||||
device_gfpgan: torch.device = model_management.get_torch_device() # will be managed by memory management system
|
||||
device_esrgan: torch.device = model_management.get_torch_device() # will be managed by memory management system
|
||||
device_codeformer: torch.device = model_management.get_torch_device() # will be managed by memory management system
|
||||
dtype: torch.dtype = model_management.unet_dtype()
|
||||
dtype_vae: torch.dtype = model_management.vae_dtype()
|
||||
dtype_unet: torch.dtype = model_management.unet_dtype()
|
||||
dtype_inference: torch.dtype = model_management.unet_dtype()
|
||||
device: torch.device = memory_management.get_torch_device()
|
||||
device_interrogate: torch.device = memory_management.text_encoder_device() # for backward compatibility, not used now
|
||||
device_gfpgan: torch.device = memory_management.get_torch_device() # will be managed by memory management system
|
||||
device_esrgan: torch.device = memory_management.get_torch_device() # will be managed by memory management system
|
||||
device_codeformer: torch.device = memory_management.get_torch_device() # will be managed by memory management system
|
||||
dtype: torch.dtype = memory_management.unet_dtype()
|
||||
dtype_vae: torch.dtype = memory_management.vae_dtype()
|
||||
dtype_unet: torch.dtype = memory_management.unet_dtype()
|
||||
dtype_inference: torch.dtype = memory_management.unet_dtype()
|
||||
unet_needs_upcast = False
|
||||
|
||||
|
||||
|
||||
@@ -11,8 +11,8 @@ from torchvision import transforms
|
||||
from torchvision.transforms.functional import InterpolationMode
|
||||
|
||||
from modules import devices, paths, shared, modelloader, errors
|
||||
from ldm_patched.modules import model_management
|
||||
from ldm_patched.modules.model_patcher import ModelPatcher
|
||||
from backend import memory_management
|
||||
from backend.patcher.base import ModelPatcher
|
||||
|
||||
|
||||
blip_image_eval_size = 384
|
||||
@@ -57,11 +57,11 @@ class InterrogateModels:
|
||||
self.skip_categories = []
|
||||
self.content_dir = content_dir
|
||||
|
||||
self.load_device = model_management.text_encoder_device()
|
||||
self.offload_device = model_management.text_encoder_offload_device()
|
||||
self.load_device = memory_management.text_encoder_device()
|
||||
self.offload_device = memory_management.text_encoder_offload_device()
|
||||
self.dtype = torch.float32
|
||||
|
||||
if model_management.should_use_fp16(device=self.load_device):
|
||||
if memory_management.should_use_fp16(device=self.load_device):
|
||||
self.dtype = torch.float16
|
||||
|
||||
self.blip_patcher = None
|
||||
@@ -137,7 +137,7 @@ class InterrogateModels:
|
||||
self.clip_model = self.clip_model.to(device=self.offload_device, dtype=self.dtype)
|
||||
self.clip_patcher = ModelPatcher(self.clip_model, load_device=self.load_device, offload_device=self.offload_device)
|
||||
|
||||
model_management.load_models_gpu([self.blip_patcher, self.clip_patcher])
|
||||
memory_management.load_models_gpu([self.blip_patcher, self.clip_patcher])
|
||||
return
|
||||
|
||||
def send_clip_to_ram(self):
|
||||
|
||||
@@ -63,13 +63,3 @@ for d, must_exist, what, options in path_dirs:
|
||||
else:
|
||||
sys.path.append(d)
|
||||
paths[what] = d
|
||||
|
||||
|
||||
import ldm_patched.utils.path_utils as ldm_patched_path_utils
|
||||
|
||||
ldm_patched_path_utils.base_path = data_path
|
||||
ldm_patched_path_utils.models_dir = models_path
|
||||
ldm_patched_path_utils.output_directory = os.path.join(data_path, "output")
|
||||
ldm_patched_path_utils.temp_directory = os.path.join(data_path, "temp")
|
||||
ldm_patched_path_utils.input_directory = os.path.join(data_path, "input")
|
||||
ldm_patched_path_utils.user_directory = os.path.join(data_path, "user")
|
||||
|
||||
@@ -14,13 +14,11 @@ import ldm.modules.midas as midas
|
||||
import gc
|
||||
|
||||
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack, patches
|
||||
from modules.shared import opts
|
||||
from modules.timer import Timer
|
||||
import numpy as np
|
||||
from modules_forge import forge_loader
|
||||
import modules_forge.ops as forge_ops
|
||||
from ldm_patched.modules.ops import manual_cast
|
||||
from ldm_patched.modules import model_management as model_management
|
||||
import ldm_patched.modules.model_patcher
|
||||
from backend import memory_management
|
||||
|
||||
|
||||
model_dir = "Stable-diffusion"
|
||||
@@ -650,8 +648,8 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
|
||||
|
||||
model_data.sd_model = None
|
||||
model_data.loaded_sd_models = []
|
||||
model_management.unload_all_models()
|
||||
model_management.soft_empty_cache()
|
||||
memory_management.unload_all_models()
|
||||
memory_management.soft_empty_cache()
|
||||
gc.collect()
|
||||
|
||||
timer.record("unload existing model")
|
||||
@@ -724,7 +722,7 @@ def apply_token_merging(sd_model, token_merging_ratio):
|
||||
|
||||
print(f'token_merging_ratio = {token_merging_ratio}')
|
||||
|
||||
from ldm_patched.contrib.external_tomesd import TomePatcher
|
||||
from backend.misc.tomesd import TomePatcher
|
||||
|
||||
sd_model.forge_objects.unet = TomePatcher().patch(
|
||||
model=sd_model.forge_objects.unet,
|
||||
|
||||
@@ -8,7 +8,7 @@ import sgm.modules.diffusionmodules.discretizer
|
||||
from modules import devices, shared, prompt_parser
|
||||
from modules import torch_utils
|
||||
|
||||
import ldm_patched.modules.model_management as model_management
|
||||
from backend import memory_management
|
||||
from modules_forge.forge_clip import move_clip_to_gpu
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch:
|
||||
is_negative_prompt = getattr(batch, 'is_negative_prompt', False)
|
||||
aesthetic_score = shared.opts.sdxl_refiner_low_aesthetic_score if is_negative_prompt else shared.opts.sdxl_refiner_high_aesthetic_score
|
||||
|
||||
devices_args = dict(device=self.forge_objects.clip.patcher.current_device, dtype=model_management.text_encoder_dtype())
|
||||
devices_args = dict(device=self.forge_objects.clip.patcher.current_device, dtype=memory_management.text_encoder_dtype())
|
||||
|
||||
sdxl_conds = {
|
||||
"txt": batch,
|
||||
|
||||
Reference in New Issue
Block a user