From d51e81483b2ad4491b9b93258bae1d176b716c43 Mon Sep 17 00:00:00 2001 From: lllyasviel Date: Sat, 27 Jan 2024 22:11:14 -0800 Subject: [PATCH] sampler --- modules/sd_samplers_common.py | 12 ++++-------- modules/sd_samplers_kdiffusion.py | 16 +++++++--------- modules/sd_samplers_timesteps.py | 16 +++++++--------- modules_forge/forge_loader.py | 1 - modules_forge/forge_sampler.py | 20 +++++++++++++++++++- 5 files changed, 37 insertions(+), 28 deletions(-) diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index ec551257..a442e150 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -5,7 +5,7 @@ import torch from PIL import Image from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared, sd_models from modules.shared import opts, state -from ldm_patched.modules import model_management +from modules_forge.forge_sampler import sampling_prepare, sampling_cleanup import k_diffusion.sampling @@ -177,20 +177,16 @@ def apply_refiner(cfg_denoiser, x): cfg_denoiser.p.extra_generation_params['Refiner'] = refiner_checkpoint_info.short_title cfg_denoiser.p.extra_generation_params['Refiner switch at'] = refiner_switch_at + sampling_cleanup(sd_models.model_data.get_sd_model().forge_objects.unet) + with sd_models.SkipWritingToConfig(): sd_models.reload_model_weights(info=refiner_checkpoint_info) - refiner = sd_models.model_data.get_sd_model() - devices.torch_gc() cfg_denoiser.p.setup_conds() cfg_denoiser.update_inner_model() - inference_memory = refiner.current_controlnet_required_memory - unet_patcher = refiner.forge_objects.unet - model_management.load_models_gpu( - [unet_patcher], - unet_patcher.memory_required([x.shape[0]] * 2 + list(x.shape[1:])) + inference_memory) + sampling_prepare(sd_models.model_data.get_sd_model().forge_objects.unet, x=x) return True diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index 23e3ea59..91580994 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -7,7 +7,7 @@ from modules.script_callbacks import ExtraNoiseParams, extra_noise_callback from modules.shared import opts import modules.shared as shared -import ldm_patched.modules.model_management +from modules_forge.forge_sampler import sampling_prepare, sampling_cleanup samplers_k_diffusion = [ @@ -141,11 +141,8 @@ class KDiffusionSampler(sd_samplers_common.Sampler): return sigmas def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): - inference_memory = self.model_wrap.inner_model.current_controlnet_required_memory unet_patcher = self.model_wrap.inner_model.forge_objects.unet - ldm_patched.modules.model_management.load_models_gpu( - [unet_patcher], - unet_patcher.memory_required([x.shape[0] * 2] + list(x.shape[1:])) + inference_memory) + sampling_prepare(self.model_wrap.inner_model.forge_objects.unet, x=x) self.model_wrap.log_sigmas = self.model_wrap.log_sigmas.to(unet_patcher.current_device) self.model_wrap.sigmas = self.model_wrap.sigmas.to(unet_patcher.current_device) @@ -201,14 +198,13 @@ class KDiffusionSampler(sd_samplers_common.Sampler): self.add_infotext(p) + sampling_cleanup(unet_patcher) + return samples def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): - inference_memory = self.model_wrap.inner_model.current_controlnet_required_memory unet_patcher = self.model_wrap.inner_model.forge_objects.unet - ldm_patched.modules.model_management.load_models_gpu( - [unet_patcher], - unet_patcher.memory_required([x.shape[0] * 2] + list(x.shape[1:])) + inference_memory) + sampling_prepare(self.model_wrap.inner_model.forge_objects.unet, x=x) self.model_wrap.log_sigmas = self.model_wrap.log_sigmas.to(unet_patcher.current_device) self.model_wrap.sigmas = self.model_wrap.sigmas.to(unet_patcher.current_device) @@ -256,6 +252,8 @@ class KDiffusionSampler(sd_samplers_common.Sampler): self.add_infotext(p) + sampling_cleanup(unet_patcher) + return samples diff --git a/modules/sd_samplers_timesteps.py b/modules/sd_samplers_timesteps.py index 8a35982f..f73cb818 100644 --- a/modules/sd_samplers_timesteps.py +++ b/modules/sd_samplers_timesteps.py @@ -7,7 +7,7 @@ from modules.script_callbacks import ExtraNoiseParams, extra_noise_callback from modules.shared import opts import modules.shared as shared -import ldm_patched.modules.model_management +from modules_forge.forge_sampler import sampling_prepare, sampling_cleanup samplers_timesteps = [ @@ -98,11 +98,8 @@ class CompVisSampler(sd_samplers_common.Sampler): return timesteps def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): - inference_memory = self.model_wrap.inner_model.current_controlnet_required_memory unet_patcher = self.model_wrap.inner_model.forge_objects.unet - ldm_patched.modules.model_management.load_models_gpu( - [unet_patcher], - unet_patcher.memory_required([x.shape[0] * 2] + list(x.shape[1:])) + inference_memory) + sampling_prepare(self.model_wrap.inner_model.forge_objects.unet, x=x) self.model_wrap.inner_model.alphas_cumprod = self.model_wrap.inner_model.alphas_cumprod.to(unet_patcher.current_device) @@ -146,14 +143,13 @@ class CompVisSampler(sd_samplers_common.Sampler): self.add_infotext(p) + sampling_cleanup(unet_patcher) + return samples def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): - inference_memory = self.model_wrap.inner_model.current_controlnet_required_memory unet_patcher = self.model_wrap.inner_model.forge_objects.unet - ldm_patched.modules.model_management.load_models_gpu( - [unet_patcher], - unet_patcher.memory_required([x.shape[0] * 2] + list(x.shape[1:])) + inference_memory) + sampling_prepare(self.model_wrap.inner_model.forge_objects.unet, x=x) self.model_wrap.inner_model.alphas_cumprod = self.model_wrap.inner_model.alphas_cumprod.to(unet_patcher.current_device) @@ -178,6 +174,8 @@ class CompVisSampler(sd_samplers_common.Sampler): self.add_infotext(p) + sampling_cleanup(unet_patcher) + return samples diff --git a/modules_forge/forge_loader.py b/modules_forge/forge_loader.py index 3da2ba4e..7c29da5d 100644 --- a/modules_forge/forge_loader.py +++ b/modules_forge/forge_loader.py @@ -241,7 +241,6 @@ def load_model_for_a1111(timer, checkpoint_info=None, state_dict=None): sd_model.decode_first_stage = patched_decode_first_stage sd_model.encode_first_stage = patched_encode_first_stage sd_model.clip = sd_model.cond_stage_model - sd_model.current_controlnet_required_memory = 0 timer.record("forge finalize") sd_model.current_lora_hash = str([]) diff --git a/modules_forge/forge_sampler.py b/modules_forge/forge_sampler.py index f2e84a83..ea6ea60f 100644 --- a/modules_forge/forge_sampler.py +++ b/modules_forge/forge_sampler.py @@ -1,6 +1,7 @@ import torch from ldm_patched.modules.conds import CONDRegular, CONDCrossAttn from ldm_patched.modules.samplers import sampling_function +from ldm_patched.modules import model_management def cond_from_a1111_to_patched_ldm(cond): @@ -72,4 +73,21 @@ def forge_sample(self, denoiser_params, cond_scale, cond_composition): return denoised -# def prepare_sampling(unet, ) +def sampling_prepare(unet, x): + B, C, H, W = x.shape + + unet_inference_memory = unet.memory_required([B * 2, C, H, W]) + additional_inference_memory = unet.controlnet_linked_list.inference_memory_requirements(unet.model_dtype()) + additional_model_patchers = unet.get_models() + + model_management.load_models_gpu( + models=[unet] + additional_model_patchers, + memory_required=unet_inference_memory + additional_inference_memory) + + return + + +def sampling_cleanup(unet): + for cnet in unet.list_controlnets(): + cnet.cleanup() + return