diff --git a/modules/devices.py b/modules/devices.py
index 0321d12c..6bffe7d4 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -179,26 +179,11 @@ def manual_cast(target_dtype):
 
 
 def autocast(disable=False):
-    if disable:
-        return contextlib.nullcontext()
-
-    if fp8 and device==cpu:
-        return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True)
-
-    if fp8 and dtype_inference == torch.float32:
-        return manual_cast(dtype)
-
-    if dtype == torch.float32 or dtype_inference == torch.float32:
-        return contextlib.nullcontext()
-
-    if has_xpu() or has_mps() or cuda_no_autocast():
-        return manual_cast(dtype)
-
-    return torch.autocast("cuda")
+    return contextlib.nullcontext()
 
 
 def without_autocast(disable=False):
-    return torch.autocast("cuda", enabled=False) if torch.is_autocast_enabled() and not disable else contextlib.nullcontext()
+    return contextlib.nullcontext()
 
 
 class NansException(Exception):
diff --git a/modules/initialize.py b/modules/initialize.py
index 7c1ac99e..a62efa29 100644
--- a/modules/initialize.py
+++ b/modules/initialize.py
@@ -12,6 +12,10 @@ def imports():
     logging.getLogger("torch.distributed.nn").setLevel(logging.ERROR) # sshh...
     logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
 
+    from modules_forge.initialization import initialize_forge
+    initialize_forge()
+    startup_timer.record("initialize forge")
+
     import torch # noqa: F401
     startup_timer.record("import torch")
     import pytorch_lightning # noqa: F401
diff --git a/modules/processing.py b/modules/processing.py
index dcc807fe..e2fb57f3 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -16,7 +16,7 @@ from skimage import exposure
 from typing import Any
 
 import modules.sd_hijack
-from modules import devices, prompt_parser, masking, sd_samplers, lowvram, infotext_utils, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng
+from modules import devices, prompt_parser, masking, sd_samplers, infotext_utils, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng
 from modules.rng import slerp # noqa: F401
 from modules.sd_hijack import model_hijack
 from modules.sd_samplers_common import images_tensor_to_samples, decode_first_stage, approximation_indexes
@@ -627,44 +627,7 @@ def decode_latent_batch(model, batch, target_device=None, check_for_nans=False):
 
     for i in range(batch.shape[0]):
         sample = decode_first_stage(model, batch[i:i + 1])[0]
-
-        if check_for_nans:
-
-            try:
-                devices.test_for_nans(sample, "vae")
-            except devices.NansException as e:
-                if shared.opts.auto_vae_precision_bfloat16:
-                    autofix_dtype = torch.bfloat16
-                    autofix_dtype_text = "bfloat16"
-                    autofix_dtype_setting = "Automatically convert VAE to bfloat16"
-                    autofix_dtype_comment = ""
-                elif shared.opts.auto_vae_precision:
-                    autofix_dtype = torch.float32
-                    autofix_dtype_text = "32-bit float"
-                    autofix_dtype_setting = "Automatically revert VAE to 32-bit floats"
-                    autofix_dtype_comment = "\nTo always start with 32-bit VAE, use --no-half-vae commandline flag."
-                else:
-                    raise e
-
-                if devices.dtype_vae == autofix_dtype:
-                    raise e
-
-                errors.print_error_explanation(
-                    "A tensor with all NaNs was produced in VAE.\n"
-                    f"Web UI will now convert VAE into {autofix_dtype_text} and retry.\n"
-                    f"To disable this behavior, disable the '{autofix_dtype_setting}' setting.{autofix_dtype_comment}"
-                )
-
-                devices.dtype_vae = autofix_dtype
-                model.first_stage_model.to(devices.dtype_vae)
-                batch = batch.to(devices.dtype_vae)
-
-                sample = decode_first_stage(model, batch[i:i + 1])[0]
-
-        if target_device is not None:
-            sample = sample.to(target_device)
-
-        samples.append(sample)
+        samples.append(sample.to(target_device))
 
     return samples
 
@@ -940,8 +903,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
                 p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule
                 p.sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(p.sd_model.alphas_cumprod).to(shared.device)
 
-            with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
-                samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
+            samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
 
             if p.scripts is not None:
                 ps = scripts.PostSampleArgs(samples_ddim)
@@ -960,9 +922,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
 
             del samples_ddim
 
-            if lowvram.is_enabled(shared.sd_model):
-                lowvram.send_everything_to_cpu()
-
             devices.torch_gc()
 
             state.nextjob()
@@ -1255,7 +1214,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
            image = np.array(self.firstpass_image).astype(np.float32) / 255.0
            image = np.moveaxis(image, 2, 0)
            image = torch.from_numpy(np.expand_dims(image, axis=0))
-            image = image.to(shared.device, dtype=devices.dtype_vae)
+            image = image.to(shared.device, dtype=torch.float32)
 
            if opts.sd_vae_encode_method != 'Full':
                self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method
@@ -1339,7 +1298,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
                batch_images.append(image)
 
            decoded_samples = torch.from_numpy(np.array(batch_images))
-            decoded_samples = decoded_samples.to(shared.device, dtype=devices.dtype_vae)
+            decoded_samples = decoded_samples.to(shared.device, dtype=torch.float32)
 
            if opts.sd_vae_encode_method != 'Full':
                self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method
@@ -1444,7 +1403,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
            if shared.opts.hires_fix_use_firstpass_conds:
                self.calculate_hr_conds()
 
-            elif lowvram.is_enabled(shared.sd_model) and shared.sd_model.sd_checkpoint_info == sd_models.select_checkpoint():  # if in lowvram mode, we need to calculate conds right away, before the cond NN is unloaded
+            elif shared.sd_model.sd_checkpoint_info == sd_models.select_checkpoint():  # if in lowvram mode, we need to calculate conds right away, before the cond NN is unloaded
                with devices.autocast():
                    extra_networks.activate(self, self.hr_extra_network_data)
 
@@ -1631,7 +1590,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
            raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")
 
        image = torch.from_numpy(batch_images)
-        image = image.to(shared.device, dtype=devices.dtype_vae)
+        image = image.to(shared.device, dtype=torch.float32)
 
        if opts.sd_vae_encode_method != 'Full':
            self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method
diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py
index cba13455..2afd1b7c 100644
--- a/modules/prompt_parser.py
+++ b/modules/prompt_parser.py
@@ -276,6 +276,12 @@ class DictWithShape(dict):
     def shape(self):
         return self["crossattn"].shape
 
+    def to(self, *args, **kwargs):
+        for k in self.keys():
+            if isinstance(self[k], torch.Tensor):
+                self[k] = self[k].to(*args, **kwargs)
+        return self
+
 
 def reconstruct_cond_batch(c: list[list[ScheduledPromptConditioning]], current_step):
     param = c[0][0].cond
diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py
index 8f29057a..426f40b6 100644
--- a/modules/sd_hijack_clip.py
+++ b/modules/sd_hijack_clip.py
@@ -3,6 +3,7 @@ from collections import namedtuple
 
 import torch
 
+import ldm_patched.modules.model_management as model_management
 from modules import prompt_parser, devices, sd_hijack
 from modules.shared import opts
 
@@ -210,6 +211,9 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
         is when you do prompt editing: "a picture of a [cat:dog:0.4] eating ice cream"
         """
 
+        if hasattr(self.wrapped, 'patcher'):
+            model_management.load_model_gpu(self.wrapped.patcher)
+
         if opts.use_old_emphasis_implementation:
             import modules.sd_hijack_clip_old
             return modules.sd_hijack_clip_old.forward_old(self, texts)
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 2c045771..972cc7d9 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -13,10 +13,16 @@ import ldm.modules.midas as midas
 
 from ldm.util import instantiate_from_config
 
-from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack, patches
+from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, sd_hijack, patches
 from modules.timer import Timer
 import tomesd
 import numpy as np
+from modules_forge import forge_loader
+import modules_forge.ops as forge_ops
+from ldm_patched.modules.ops import manual_cast
+from ldm_patched.modules import model_management as model_management
+import ldm_patched.modules.model_patcher
+
 
 model_dir = "Stable-diffusion"
 model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
@@ -366,10 +372,6 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
     sd_model_hash = checkpoint_info.calculate_shorthash()
     timer.record("calculate hash")
 
-    if devices.fp8:
-        # prevent model to load state dict in fp8
-        model.half()
-
     if not SkipWritingToConfig.skip:
         shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
 
@@ -379,13 +381,10 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
     model.is_sdxl = hasattr(model, 'conditioner')
     model.is_sd2 = not model.is_sdxl and hasattr(model.cond_stage_model, 'model')
     model.is_sd1 = not model.is_sdxl and not model.is_sd2
-    model.is_ssd = model.is_sdxl and 'model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight' not in state_dict.keys()
+    model.is_ssd = model.is_sdxl and 'model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight' not in model.state_dict().keys()
     if model.is_sdxl:
         sd_models_xl.extend_sdxl(model)
 
-    if model.is_ssd:
-        sd_hijack.model_hijack.convert_sdxl_to_ssd(model)
-
     if shared.opts.sd_checkpoint_cache > 0:
         # cache newly loaded model
         checkpoints_loaded[checkpoint_info] = state_dict.copy()
@@ -395,65 +394,6 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
 
     del state_dict
 
-    if shared.cmd_opts.opt_channelslast:
-        model.to(memory_format=torch.channels_last)
-        timer.record("apply channels_last")
-
-    if shared.cmd_opts.no_half:
-        model.float()
-        model.alphas_cumprod_original = model.alphas_cumprod
-        devices.dtype_unet = torch.float32
-        timer.record("apply float()")
-    else:
-        vae = model.first_stage_model
-        depth_model = getattr(model, 'depth_model', None)
-
-        # with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
-        if shared.cmd_opts.no_half_vae:
-            model.first_stage_model = None
-        # with --upcast-sampling, don't convert the depth model weights to float16
-        if shared.cmd_opts.upcast_sampling and depth_model:
-            model.depth_model = None
-
-        alphas_cumprod = model.alphas_cumprod
-        model.alphas_cumprod = None
-        model.half()
-        model.alphas_cumprod = alphas_cumprod
-        model.alphas_cumprod_original = alphas_cumprod
-        model.first_stage_model = vae
-        if depth_model:
-            model.depth_model = depth_model
-
-        devices.dtype_unet = torch.float16
-        timer.record("apply half()")
-
-    for module in model.modules():
-        if hasattr(module, 'fp16_weight'):
-            del module.fp16_weight
-        if hasattr(module, 'fp16_bias'):
-            del module.fp16_bias
-
-    if check_fp8(model):
-        devices.fp8 = True
-        first_stage = model.first_stage_model
-        model.first_stage_model = None
-        for module in model.modules():
-            if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
-                if shared.opts.cache_fp16_weight:
-                    module.fp16_weight = module.weight.data.clone().cpu().half()
-                    if module.bias is not None:
-                        module.fp16_bias = module.bias.data.clone().cpu().half()
-                module.to(torch.float8_e4m3fn)
-        model.first_stage_model = first_stage
-        timer.record("apply fp8")
-    else:
-        devices.fp8 = False
-
-    devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
-
-    model.first_stage_model.to(devices.dtype_vae)
-    timer.record("apply dtype to VAE")
-
     # clean up cache if limit is reached
     while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
         checkpoints_loaded.popitem(last=False)
@@ -614,34 +554,6 @@ def get_empty_cond(sd_model):
         return sd_model.cond_stage_model([""])
 
 
-def send_model_to_cpu(m):
-    if m.lowvram:
-        lowvram.send_everything_to_cpu()
-    else:
-        m.to(devices.cpu)
-
-    devices.torch_gc()
-
-
-def model_target_device(m):
-    if lowvram.is_needed(m):
-        return devices.cpu
-    else:
-        return devices.device
-
-
-def send_model_to_device(m):
-    lowvram.apply(m)
-
-    if not m.lowvram:
-        m.to(shared.device)
-
-
-def send_model_to_trash(m):
-    m.to(device="meta")
-    devices.torch_gc()
-
-
 def load_model(checkpoint_info=None, already_loaded_state_dict=None):
     from modules import sd_hijack
     checkpoint_info = checkpoint_info or select_checkpoint()
@@ -649,7 +561,6 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
     timer = Timer()
 
     if model_data.sd_model:
-        send_model_to_trash(model_data.sd_model)
         model_data.sd_model = None
         devices.torch_gc()
 
@@ -670,12 +581,21 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
 
     timer.record("load config")
 
+    if hasattr(sd_config.model.params, 'network_config'):
+        sd_config.model.params.network_config.target = 'modules_forge.forge_loader.FakeObject'
+
+    if hasattr(sd_config.model.params, 'unet_config'):
+        sd_config.model.params.unet_config.target = 'modules_forge.forge_loader.FakeObject'
+
+    if hasattr(sd_config.model.params, 'first_stage_config'):
+        sd_config.model.params.first_stage_config.target = 'modules_forge.forge_loader.FakeObject'
+
     print(f"Creating model from config: {checkpoint_config}")
 
     sd_model = None
     try:
         with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd or shared.cmd_opts.do_not_download_clip):
-            with sd_disable_initialization.InitializeOnMeta():
+            with forge_ops.use_patched_ops(manual_cast):
                 sd_model = instantiate_from_config(sd_config.model)
 
     except Exception as e:
@@ -684,28 +604,45 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
 
     if sd_model is None:
         print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)
 
-        with sd_disable_initialization.InitializeOnMeta():
+        with forge_ops.use_patched_ops(manual_cast):
             sd_model = instantiate_from_config(sd_config.model)
 
     sd_model.used_config = checkpoint_config
 
     timer.record("create model")
 
-    if shared.cmd_opts.no_half:
-        weight_dtype_conversion = None
-    else:
-        weight_dtype_conversion = {
-            'first_stage_model': None,
-            'alphas_cumprod': None,
-            '': torch.float16,
-        }
+    state_dict_for_a1111 = {k: v for k, v in state_dict.items() if not k.startswith('model.diffusion_model.') and not k.startswith('first_stage_model.')}
+    state_dict_for_forge = {k: v for k, v in state_dict.items()}
+    del state_dict
 
-    with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(sd_model), weight_dtype_conversion=weight_dtype_conversion):
-        load_model_weights(sd_model, checkpoint_info, state_dict, timer)
+    unet_patcher, vae_patcher = forge_loader.load_unet_and_vae(state_dict_for_forge)
+    sd_model.first_stage_model = vae_patcher.first_stage_model
+    sd_model.model.diffusion_model = unet_patcher.model.diffusion_model
+    sd_model.unet_patcher = unet_patcher
+    sd_model.model.diffusion_model.patcher = unet_patcher
+    sd_model.vae_patcher = vae_patcher
+    sd_model.first_stage_model.patcher = vae_patcher
+    timer.record("create unet patcher")
+    del state_dict_for_forge
+
+    load_model_weights(sd_model, checkpoint_info, state_dict_for_a1111, timer)
+    del state_dict_for_a1111
 
     timer.record("load weights from state dict")
 
-    send_model_to_device(sd_model)
-    timer.record("move model to device")
+    current_clip = sd_model.conditioner if hasattr(sd_model, 'conditioner') else sd_model.cond_stage_model
+    clip_load_device = model_management.text_encoder_device()
+    clip_offload_device = model_management.text_encoder_offload_device()
+    clip_dtype = model_management.text_encoder_dtype()
+
+    current_clip.to(clip_dtype)
+    clip_patcher = ldm_patched.modules.model_patcher.ModelPatcher(
+        current_clip,
+        load_device=clip_load_device,
+        offload_device=clip_offload_device
+    )
+    sd_model.clip_patcher = clip_patcher
+    current_clip.patcher = clip_patcher
+    timer.record("create clip patcher")
 
     sd_hijack.model_hijack.hijack(sd_model)
@@ -752,17 +689,8 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
         if len(model_data.loaded_sd_models) > shared.opts.sd_checkpoints_limit > 0:
             print(f"Unloading model {len(model_data.loaded_sd_models)} over the limit of {shared.opts.sd_checkpoints_limit}: {loaded_model.sd_checkpoint_info.title}")
             model_data.loaded_sd_models.pop()
-            send_model_to_trash(loaded_model)
-            timer.record("send model to trash")
-
-    if shared.opts.sd_checkpoints_keep_in_cpu:
-        send_model_to_cpu(sd_model)
cpu") if already_loaded is not None: - send_model_to_device(already_loaded) - timer.record("send model to device") - model_data.set_sd_model(already_loaded, already_loaded=True) if not SkipWritingToConfig.skip: @@ -816,7 +744,6 @@ def reload_model_weights(sd_model=None, info=None, forced_reload=False): if sd_model is not None: sd_unet.apply_unet("None") - send_model_to_cpu(sd_model) sd_hijack.model_hijack.undo_hijack(sd_model) state_dict = get_checkpoint_state_dict(checkpoint_info, timer) @@ -827,7 +754,7 @@ def reload_model_weights(sd_model=None, info=None, forced_reload=False): if sd_model is None or checkpoint_config != sd_model.used_config: if sd_model is not None: - send_model_to_trash(sd_model) + sd_model = None load_model(checkpoint_info, already_loaded_state_dict=state_dict) return model_data.sd_model @@ -857,12 +784,6 @@ def reload_model_weights(sd_model=None, info=None, forced_reload=False): return sd_model -def unload_model_weights(sd_model=None, info=None): - send_model_to_cpu(sd_model or shared.sd_model) - - return sd_model - - def apply_token_merging(sd_model, token_merging_ratio): """ Applies speed and memory optimizations from tomesd. diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py index 0de17af3..1418348a 100644 --- a/modules/sd_models_xl.py +++ b/modules/sd_models_xl.py @@ -8,8 +8,12 @@ import sgm.modules.diffusionmodules.discretizer from modules import devices, shared, prompt_parser from modules import torch_utils +import ldm_patched.modules.model_management as model_management + def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: prompt_parser.SdConditioning | list[str]): + model_management.load_model_gpu(self.clip_patcher) + for embedder in self.conditioner.embedders: embedder.ucg_rate = 0.0 @@ -18,7 +22,7 @@ def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: is_negative_prompt = getattr(batch, 'is_negative_prompt', False) aesthetic_score = shared.opts.sdxl_refiner_low_aesthetic_score if is_negative_prompt else shared.opts.sdxl_refiner_high_aesthetic_score - devices_args = dict(device=devices.device, dtype=devices.dtype) + devices_args = dict(device=self.clip_patcher.current_device, dtype=model_management.text_encoder_dtype()) sdxl_conds = { "txt": batch, @@ -35,11 +39,8 @@ def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond): - sd = self.model.state_dict() - diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None) - if diffusion_model_input is not None: - if diffusion_model_input.shape[1] == 9: - x = torch.cat([x] + cond['c_concat'], dim=1) + if self.model.diffusion_model.in_channels == 9: + x = torch.cat([x] + cond['c_concat'], dim=1) return self.model(x, t, cond) diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py index 6d76aa96..47d8f644 100644 --- a/modules/sd_samplers_cfg_denoiser.py +++ b/modules/sd_samplers_cfg_denoiser.py @@ -173,6 +173,15 @@ class CFGDenoiser(torch.nn.Module): uncond = pad_cond(uncond, num_repeats, empty) self.padded_cond_uncond = True + unet_dtype = self.inner_model.inner_model.unet_patcher.model.model_config.unet_config['dtype'] + x_input_dtype = x_in.dtype + + x_in = x_in.to(unet_dtype) + sigma_in = sigma_in.to(unet_dtype) + image_cond_in = image_cond_in.to(unet_dtype) + tensor = tensor.to(unet_dtype) + uncond = uncond.to(unet_dtype) + if tensor.shape[1] == uncond.shape[1] or skip_uncond: if is_edit_model: cond_in 
                 cond_in = catenate_conds([tensor, uncond, uncond])
@@ -211,6 +220,8 @@ class CFGDenoiser(torch.nn.Module):
             fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes])
             x_out = torch.cat([x_out, fake_uncond]) # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be
 
+        x_out = x_out.to(x_input_dtype)
+
         denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model)
         cfg_denoised_callback(denoised_params)
 
diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py
index 58efcad2..c45db906 100644
--- a/modules/sd_samplers_common.py
+++ b/modules/sd_samplers_common.py
@@ -39,9 +39,7 @@ def samples_to_images_tensor(sample, approximation=None, model=None):
 
     if approximation is None or (shared.state.interrupted and opts.live_preview_fast_interrupt):
         approximation = approximation_indexes.get(opts.show_progress_type, 0)
-
-        from modules import lowvram
-        if approximation == 0 and lowvram.is_enabled(shared.sd_model) and not shared.opts.live_preview_allow_lowvram_full:
+        if approximation == 0:
             approximation = 1
 
     if approximation == 2:
@@ -54,8 +52,8 @@ def samples_to_images_tensor(sample, approximation=None, model=None):
     else:
         if model is None:
             model = shared.sd_model
-        with devices.without_autocast(): # fixes an issue with unstable VAEs that are flaky even in fp32
-            x_sample = model.decode_first_stage(sample.to(model.first_stage_model.dtype))
+        sample = model.unet_patcher.model.model_config.latent_format.process_out(sample)
+        x_sample = model.vae_patcher.decode(sample).movedim(-1, 1) * 2.0 - 1.0
 
     return x_sample
 
@@ -71,7 +69,6 @@ def single_sample_to_image(sample, approximation=None):
 
 
 def decode_first_stage(model, x):
-    x = x.to(devices.dtype_vae)
     approx_index = approximation_indexes.get(opts.sd_vae_decode_method, 0)
     return samples_to_images_tensor(x, approx_index, model)
 
diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
index 8a8c87e0..5e115486 100644
--- a/modules/sd_samplers_kdiffusion.py
+++ b/modules/sd_samplers_kdiffusion.py
@@ -7,6 +7,8 @@ from modules.script_callbacks import ExtraNoiseParams, extra_noise_callback
 
 from modules.shared import opts
 import modules.shared as shared
 
+import ldm_patched.modules.model_management
+
 samplers_k_diffusion = [
     ('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
@@ -139,6 +141,12 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
         return sigmas
 
     def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
+        inference_memory = 0
+        unet_patcher = self.model_wrap.inner_model.unet_patcher
+        ldm_patched.modules.model_management.load_models_gpu(
+            [unet_patcher],
+            unet_patcher.memory_required([x.shape[0] * 2] + list(x.shape[1:])) + inference_memory)
+
         steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
 
         sigmas = self.get_sigmas(p, steps)
@@ -193,6 +201,12 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
         return samples
 
     def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
+        inference_memory = 0
+        unet_patcher = self.model_wrap.inner_model.unet_patcher
+        ldm_patched.modules.model_management.load_models_gpu(
+            [unet_patcher],
+            unet_patcher.memory_required([x.shape[0] * 2] + list(x.shape[1:])) + inference_memory)
+
         steps = steps or p.steps
 
         sigmas = self.get_sigmas(p, steps)
diff --git a/modules/sd_samplers_lcm.py b/modules/sd_samplers_lcm.py
index 59839b72..45df4db9 100644
--- a/modules/sd_samplers_lcm.py
+++ b/modules/sd_samplers_lcm.py
@@ -34,7 +34,7 @@ class LCMCompVisDenoiser(DiscreteEpsDDPMDenoiser):
 
     def sigma_to_t(self, sigma, quantize=None):
         log_sigma = sigma.log()
-        dists = log_sigma - self.log_sigmas[:, None]
+        dists = log_sigma - self.log_sigmas.to(sigma)[:, None]
         return dists.abs().argmin(dim=0).view(sigma.shape) * self.skip_steps + (self.skip_steps - 1)
 
 
diff --git a/modules/sd_samplers_timesteps.py b/modules/sd_samplers_timesteps.py
index 777dd8d0..6d10ba61 100644
--- a/modules/sd_samplers_timesteps.py
+++ b/modules/sd_samplers_timesteps.py
@@ -7,6 +7,8 @@ from modules.script_callbacks import ExtraNoiseParams, extra_noise_callback
 
 from modules.shared import opts
 import modules.shared as shared
 
+import ldm_patched.modules.model_management
+
 samplers_timesteps = [
     ('DDIM', sd_samplers_timesteps_impl.ddim, ['ddim'], {}),
@@ -95,6 +97,12 @@ class CompVisSampler(sd_samplers_common.Sampler):
         return timesteps
 
     def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
+        inference_memory = 0
+        unet_patcher = self.model_wrap.inner_model.unet_patcher
+        ldm_patched.modules.model_management.load_models_gpu(
+            [unet_patcher],
+            unet_patcher.memory_required([x.shape[0] * 2] + list(x.shape[1:])) + inference_memory)
+
         steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
 
         timesteps = self.get_timesteps(p, steps)
@@ -139,6 +147,12 @@ class CompVisSampler(sd_samplers_common.Sampler):
         return samples
 
     def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
+        inference_memory = 0
+        unet_patcher = self.model_wrap.inner_model.unet_patcher
+        ldm_patched.modules.model_management.load_models_gpu(
+            [unet_patcher],
+            unet_patcher.memory_required([x.shape[0] * 2] + list(x.shape[1:])) + inference_memory)
+
         steps = steps or p.steps
 
         timesteps = self.get_timesteps(p, steps)
diff --git a/modules/sd_vae.py b/modules/sd_vae.py
index 43687e48..9f37ed6c 100644
--- a/modules/sd_vae.py
+++ b/modules/sd_vae.py
@@ -2,7 +2,7 @@ import os
 import collections
 from dataclasses import dataclass
 
-from modules import paths, shared, devices, script_callbacks, sd_models, extra_networks, lowvram, sd_hijack, hashes
+from modules import paths, shared, devices, script_callbacks, sd_models, extra_networks, sd_hijack, hashes
 import glob
 from copy import deepcopy
 
@@ -237,7 +237,6 @@ def load_vae(model, vae_file=None, vae_source="from unknown source"):
 # don't call this from outside
 def _load_vae_dict(model, vae_dict_1):
     model.first_stage_model.load_state_dict(vae_dict_1)
-    model.first_stage_model.to(devices.dtype_vae)
 
 
 def clear_loaded_vae():
@@ -263,20 +262,12 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified):
     if loaded_vae_file == vae_file:
         return
 
-    if sd_model.lowvram:
-        lowvram.send_everything_to_cpu()
-    else:
-        sd_model.to(devices.cpu)
-
     sd_hijack.model_hijack.undo_hijack(sd_model)
 
     load_vae(sd_model, vae_file, vae_source)
 
     sd_hijack.model_hijack.hijack(sd_model)
 
-    if not sd_model.lowvram:
-        sd_model.to(devices.device)
-
     script_callbacks.model_loaded_callback(sd_model)
 
     print("VAE weights loaded.")
diff --git a/modules_forge/forge_loader.py b/modules_forge/forge_loader.py
new file mode 100644
index 00000000..582d2e27
--- /dev/null
+++ b/modules_forge/forge_loader.py
@@ -0,0 +1,42 @@
+import torch
+
+from ldm_patched.modules import model_management
+from ldm_patched.modules import model_detection
+
+from ldm_patched.modules.sd import VAE
+import ldm_patched.modules.model_patcher
+import ldm_patched.modules.utils
+
+
+def load_unet_and_vae(sd):
+    parameters = ldm_patched.modules.utils.calculate_parameters(sd, "model.diffusion_model.")
+    unet_dtype = model_management.unet_dtype(model_params=parameters)
+    load_device = model_management.get_torch_device()
+    manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device)
+
+    model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.", unet_dtype)
+    if model_config is None:
+        raise RuntimeError("ERROR: Could not detect model type.")
+
+    model_config.set_manual_cast(manual_cast_dtype)
+
+    initial_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
+    model = model_config.get_model(sd, "model.diffusion_model.", device=initial_load_device)
+    model.load_model_weights(sd, "model.diffusion_model.")
+
+    model_patcher = ldm_patched.modules.model_patcher.ModelPatcher(model,
+                                                                   load_device=load_device,
+                                                                   offload_device=model_management.unet_offload_device(),
+                                                                   current_device=initial_load_device)
+
+    vae_sd = ldm_patched.modules.utils.state_dict_prefix_replace(sd, {"first_stage_model.": ""}, filter_keys=True)
+    vae_sd = model_config.process_vae_state_dict(vae_sd)
+    vae_patcher = VAE(sd=vae_sd)
+
+    return model_patcher, vae_patcher
+
+
+class FakeObject(torch.nn.Module):
+    def __init__(self, *args, **kwargs):
+        super().__init__()
+        return
diff --git a/modules_forge/initialization.py b/modules_forge/initialization.py
new file mode 100644
index 00000000..176a78f6
--- /dev/null
+++ b/modules_forge/initialization.py
@@ -0,0 +1,66 @@
+import argparse
+
+
+def initialize_forge():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--no-vram", action="store_true")
+    parser.add_argument("--normal-vram", action="store_true")
+    parser.add_argument("--high-vram", action="store_true")
+    parser.add_argument("--always-vram", action="store_true")
+    vram_args = vars(parser.parse_known_args()[0])
+
+    parser = argparse.ArgumentParser()
+    attn_group = parser.add_mutually_exclusive_group()
+    attn_group.add_argument("--attention-split", action="store_true")
+    attn_group.add_argument("--attention-quad", action="store_true")
+    attn_group.add_argument("--attention-pytorch", action="store_true")
+    parser.add_argument("--disable-xformers", action="store_true")
+    fpte_group = parser.add_mutually_exclusive_group()
+    fpte_group.add_argument("--clip-in-fp8-e4m3fn", action="store_true")
+    fpte_group.add_argument("--clip-in-fp8-e5m2", action="store_true")
+    fpte_group.add_argument("--clip-in-fp16", action="store_true")
+    fpte_group.add_argument("--clip-in-fp32", action="store_true")
+    fp_group = parser.add_mutually_exclusive_group()
+    fp_group.add_argument("--all-in-fp32", action="store_true")
+    fp_group.add_argument("--all-in-fp16", action="store_true")
+    fpunet_group = parser.add_mutually_exclusive_group()
+    fpunet_group.add_argument("--unet-in-bf16", action="store_true")
+    fpunet_group.add_argument("--unet-in-fp16", action="store_true")
+    fpunet_group.add_argument("--unet-in-fp8-e4m3fn", action="store_true")
+    fpunet_group.add_argument("--unet-in-fp8-e5m2", action="store_true")
+    fpvae_group = parser.add_mutually_exclusive_group()
+    fpvae_group.add_argument("--vae-in-fp16", action="store_true")
+    fpvae_group.add_argument("--vae-in-fp32", action="store_true")
+    fpvae_group.add_argument("--vae-in-bf16", action="store_true")
+    other_args = vars(parser.parse_known_args()[0])
+
+    from ldm_patched.modules.args_parser import args
+
+    args.always_cpu = False
+    args.always_gpu = False
+    args.always_high_vram = False
+    args.always_low_vram = False
+    args.always_no_vram = False
+    args.always_offload_from_vram = True
+    args.async_cuda_allocation = False
+    args.disable_async_cuda_allocation = True
+
+    if vram_args['no_vram']:
+        args.always_cpu = True
+
+    if vram_args['always_vram']:
+        args.always_gpu = True
+
+    if vram_args['high_vram']:
+        args.always_offload_from_vram = False
+
+    for k, v in other_args.items():
+        setattr(args, k, v)
+
+    import ldm_patched.modules.model_management as model_management
+    import torch
+
+    device = model_management.get_torch_device()
+    torch.zeros((1, 1)).to(device, torch.float32)
+    model_management.soft_empty_cache()
+    return
diff --git a/modules_forge/ops.py b/modules_forge/ops.py
new file mode 100644
index 00000000..ee0e7756
--- /dev/null
+++ b/modules_forge/ops.py
@@ -0,0 +1,19 @@
+import torch
+import contextlib
+
+
+@contextlib.contextmanager
+def use_patched_ops(operations):
+    op_names = ['Linear', 'Conv2d', 'Conv3d', 'GroupNorm', 'LayerNorm']
+    backups = {op_name: getattr(torch.nn, op_name) for op_name in op_names}
+
+    try:
+        for op_name in op_names:
+            setattr(torch.nn, op_name, getattr(operations, op_name))
+
+        yield
+
+    finally:
+        for op_name in op_names:
+            setattr(torch.nn, op_name, backups[op_name])
+        return
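
Note on the new use_patched_ops context manager (modules_forge/ops.py): model creation in sd_models.py now runs inside forge_ops.use_patched_ops(manual_cast), so every Linear/Conv/Norm module built by instantiate_from_config is constructed from the patched operator classes rather than the stock torch.nn ones, and the originals are restored afterwards. Below is a minimal, self-contained sketch of that pattern; PatchedLinear and the "patched" namespace are illustrative stand-ins, not the actual ldm_patched.modules.ops.manual_cast classes.

import contextlib
import types

import torch


@contextlib.contextmanager
def use_patched_ops(operations):
    # Same pattern as modules_forge/ops.py: temporarily rebind the torch.nn
    # operator classes to the patched ones, then restore the originals.
    op_names = ['Linear', 'Conv2d', 'Conv3d', 'GroupNorm', 'LayerNorm']
    backups = {name: getattr(torch.nn, name) for name in op_names}
    try:
        for name in op_names:
            setattr(torch.nn, name, getattr(operations, name))
        yield
    finally:
        for name in op_names:
            setattr(torch.nn, name, backups[name])


class PatchedLinear(torch.nn.Linear):
    # Illustrative stand-in for a casting-aware layer; not part of the diff.
    pass


# Hypothetical namespace standing in for ldm_patched.modules.ops.manual_cast:
# anything that exposes the five class names above works.
patched = types.SimpleNamespace(
    Linear=PatchedLinear,
    Conv2d=torch.nn.Conv2d,
    Conv3d=torch.nn.Conv3d,
    GroupNorm=torch.nn.GroupNorm,
    LayerNorm=torch.nn.LayerNorm,
)

with use_patched_ops(patched):
    layer = torch.nn.Linear(4, 4)  # constructs PatchedLinear while patched

assert isinstance(layer, PatchedLinear)
assert torch.nn.Linear is not PatchedLinear  # originals restored on exit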