mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-03-13 08:59:51 +00:00
sampler
This commit is contained in:
@@ -5,7 +5,7 @@ import torch
|
||||
from PIL import Image
|
||||
from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared, sd_models
|
||||
from modules.shared import opts, state
|
||||
from ldm_patched.modules import model_management
|
||||
from modules_forge.forge_sampler import sampling_prepare, sampling_cleanup
|
||||
import k_diffusion.sampling
|
||||
|
||||
|
||||
@@ -177,20 +177,16 @@ def apply_refiner(cfg_denoiser, x):
|
||||
cfg_denoiser.p.extra_generation_params['Refiner'] = refiner_checkpoint_info.short_title
|
||||
cfg_denoiser.p.extra_generation_params['Refiner switch at'] = refiner_switch_at
|
||||
|
||||
sampling_cleanup(sd_models.model_data.get_sd_model().forge_objects.unet)
|
||||
|
||||
with sd_models.SkipWritingToConfig():
|
||||
sd_models.reload_model_weights(info=refiner_checkpoint_info)
|
||||
|
||||
refiner = sd_models.model_data.get_sd_model()
|
||||
|
||||
devices.torch_gc()
|
||||
cfg_denoiser.p.setup_conds()
|
||||
cfg_denoiser.update_inner_model()
|
||||
|
||||
inference_memory = refiner.current_controlnet_required_memory
|
||||
unet_patcher = refiner.forge_objects.unet
|
||||
model_management.load_models_gpu(
|
||||
[unet_patcher],
|
||||
unet_patcher.memory_required([x.shape[0]] * 2 + list(x.shape[1:])) + inference_memory)
|
||||
sampling_prepare(sd_models.model_data.get_sd_model().forge_objects.unet, x=x)
|
||||
return True
|
||||
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ from modules.script_callbacks import ExtraNoiseParams, extra_noise_callback
|
||||
|
||||
from modules.shared import opts
|
||||
import modules.shared as shared
|
||||
import ldm_patched.modules.model_management
|
||||
from modules_forge.forge_sampler import sampling_prepare, sampling_cleanup
|
||||
|
||||
|
||||
samplers_k_diffusion = [
|
||||
@@ -141,11 +141,8 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
|
||||
return sigmas
|
||||
|
||||
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||
inference_memory = self.model_wrap.inner_model.current_controlnet_required_memory
|
||||
unet_patcher = self.model_wrap.inner_model.forge_objects.unet
|
||||
ldm_patched.modules.model_management.load_models_gpu(
|
||||
[unet_patcher],
|
||||
unet_patcher.memory_required([x.shape[0] * 2] + list(x.shape[1:])) + inference_memory)
|
||||
sampling_prepare(self.model_wrap.inner_model.forge_objects.unet, x=x)
|
||||
|
||||
self.model_wrap.log_sigmas = self.model_wrap.log_sigmas.to(unet_patcher.current_device)
|
||||
self.model_wrap.sigmas = self.model_wrap.sigmas.to(unet_patcher.current_device)
|
||||
@@ -201,14 +198,13 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
|
||||
|
||||
self.add_infotext(p)
|
||||
|
||||
sampling_cleanup(unet_patcher)
|
||||
|
||||
return samples
|
||||
|
||||
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||
inference_memory = self.model_wrap.inner_model.current_controlnet_required_memory
|
||||
unet_patcher = self.model_wrap.inner_model.forge_objects.unet
|
||||
ldm_patched.modules.model_management.load_models_gpu(
|
||||
[unet_patcher],
|
||||
unet_patcher.memory_required([x.shape[0] * 2] + list(x.shape[1:])) + inference_memory)
|
||||
sampling_prepare(self.model_wrap.inner_model.forge_objects.unet, x=x)
|
||||
|
||||
self.model_wrap.log_sigmas = self.model_wrap.log_sigmas.to(unet_patcher.current_device)
|
||||
self.model_wrap.sigmas = self.model_wrap.sigmas.to(unet_patcher.current_device)
|
||||
@@ -256,6 +252,8 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
|
||||
|
||||
self.add_infotext(p)
|
||||
|
||||
sampling_cleanup(unet_patcher)
|
||||
|
||||
return samples
|
||||
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ from modules.script_callbacks import ExtraNoiseParams, extra_noise_callback
|
||||
|
||||
from modules.shared import opts
|
||||
import modules.shared as shared
|
||||
import ldm_patched.modules.model_management
|
||||
from modules_forge.forge_sampler import sampling_prepare, sampling_cleanup
|
||||
|
||||
|
||||
samplers_timesteps = [
|
||||
@@ -98,11 +98,8 @@ class CompVisSampler(sd_samplers_common.Sampler):
|
||||
return timesteps
|
||||
|
||||
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||
inference_memory = self.model_wrap.inner_model.current_controlnet_required_memory
|
||||
unet_patcher = self.model_wrap.inner_model.forge_objects.unet
|
||||
ldm_patched.modules.model_management.load_models_gpu(
|
||||
[unet_patcher],
|
||||
unet_patcher.memory_required([x.shape[0] * 2] + list(x.shape[1:])) + inference_memory)
|
||||
sampling_prepare(self.model_wrap.inner_model.forge_objects.unet, x=x)
|
||||
|
||||
self.model_wrap.inner_model.alphas_cumprod = self.model_wrap.inner_model.alphas_cumprod.to(unet_patcher.current_device)
|
||||
|
||||
@@ -146,14 +143,13 @@ class CompVisSampler(sd_samplers_common.Sampler):
|
||||
|
||||
self.add_infotext(p)
|
||||
|
||||
sampling_cleanup(unet_patcher)
|
||||
|
||||
return samples
|
||||
|
||||
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||
inference_memory = self.model_wrap.inner_model.current_controlnet_required_memory
|
||||
unet_patcher = self.model_wrap.inner_model.forge_objects.unet
|
||||
ldm_patched.modules.model_management.load_models_gpu(
|
||||
[unet_patcher],
|
||||
unet_patcher.memory_required([x.shape[0] * 2] + list(x.shape[1:])) + inference_memory)
|
||||
sampling_prepare(self.model_wrap.inner_model.forge_objects.unet, x=x)
|
||||
|
||||
self.model_wrap.inner_model.alphas_cumprod = self.model_wrap.inner_model.alphas_cumprod.to(unet_patcher.current_device)
|
||||
|
||||
@@ -178,6 +174,8 @@ class CompVisSampler(sd_samplers_common.Sampler):
|
||||
|
||||
self.add_infotext(p)
|
||||
|
||||
sampling_cleanup(unet_patcher)
|
||||
|
||||
return samples
|
||||
|
||||
|
||||
|
||||
@@ -241,7 +241,6 @@ def load_model_for_a1111(timer, checkpoint_info=None, state_dict=None):
|
||||
sd_model.decode_first_stage = patched_decode_first_stage
|
||||
sd_model.encode_first_stage = patched_encode_first_stage
|
||||
sd_model.clip = sd_model.cond_stage_model
|
||||
sd_model.current_controlnet_required_memory = 0
|
||||
timer.record("forge finalize")
|
||||
|
||||
sd_model.current_lora_hash = str([])
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import torch
|
||||
from ldm_patched.modules.conds import CONDRegular, CONDCrossAttn
|
||||
from ldm_patched.modules.samplers import sampling_function
|
||||
from ldm_patched.modules import model_management
|
||||
|
||||
|
||||
def cond_from_a1111_to_patched_ldm(cond):
|
||||
@@ -72,4 +73,21 @@ def forge_sample(self, denoiser_params, cond_scale, cond_composition):
|
||||
return denoised
|
||||
|
||||
|
||||
# def prepare_sampling(unet, )
|
||||
def sampling_prepare(unet, x):
|
||||
B, C, H, W = x.shape
|
||||
|
||||
unet_inference_memory = unet.memory_required([B * 2, C, H, W])
|
||||
additional_inference_memory = unet.controlnet_linked_list.inference_memory_requirements(unet.model_dtype())
|
||||
additional_model_patchers = unet.get_models()
|
||||
|
||||
model_management.load_models_gpu(
|
||||
models=[unet] + additional_model_patchers,
|
||||
memory_required=unet_inference_memory + additional_inference_memory)
|
||||
|
||||
return
|
||||
|
||||
|
||||
def sampling_cleanup(unet):
|
||||
for cnet in unet.list_controlnets():
|
||||
cnet.cleanup()
|
||||
return
|
||||
|
||||
Reference in New Issue
Block a user