mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-04-29 10:41:25 +00:00
add_sampler_pre_cfg_function
This commit is contained in:
@@ -55,11 +55,17 @@ class PreprocessorInpaintOnly(PreprocessorInpaint):
|
|||||||
|
|
||||||
unet = process.sd_model.forge_objects.unet.clone()
|
unet = process.sd_model.forge_objects.unet.clone()
|
||||||
|
|
||||||
|
def pre_cfg(model, c, uc, x, timestep, model_options):
|
||||||
|
noisy_latent = latent_image.to(x) + timestep.to(x) * torch.randn_like(latent_image).to(x)
|
||||||
|
x = x * latent_mask.to(x) + noisy_latent.to(x) * (1.0 - latent_mask.to(x))
|
||||||
|
return model, c, uc, x, timestep, model_options
|
||||||
|
|
||||||
def post_cfg(args):
|
def post_cfg(args):
|
||||||
denoised = args['denoised']
|
denoised = args['denoised']
|
||||||
denoised = denoised * latent_mask.to(denoised) + latent_image.to(denoised) * (1.0 - latent_mask.to(denoised))
|
denoised = denoised * latent_mask.to(denoised) + latent_image.to(denoised) * (1.0 - latent_mask.to(denoised))
|
||||||
return denoised
|
return denoised
|
||||||
|
|
||||||
|
unet.add_sampler_pre_cfg_function(pre_cfg)
|
||||||
unet.set_model_sampler_post_cfg_function(post_cfg)
|
unet.set_model_sampler_post_cfg_function(post_cfg)
|
||||||
|
|
||||||
process.sd_model.forge_objects.unet = unet
|
process.sd_model.forge_objects.unet = unet
|
||||||
|
|||||||
@@ -276,6 +276,9 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option
|
|||||||
else:
|
else:
|
||||||
uncond_ = uncond
|
uncond_ = uncond
|
||||||
|
|
||||||
|
for fn in model_options.get("sampler_pre_cfg_function", []):
|
||||||
|
model, cond, uncond_, x, timestep, model_options = fn(model, cond, uncond_, x, timestep, model_options)
|
||||||
|
|
||||||
cond_pred, uncond_pred = calc_cond_uncond_batch(model, cond, uncond_, x, timestep, model_options)
|
cond_pred, uncond_pred = calc_cond_uncond_batch(model, cond, uncond_, x, timestep, model_options)
|
||||||
if "sampler_cfg_function" in model_options:
|
if "sampler_cfg_function" in model_options:
|
||||||
args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep,
|
args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep,
|
||||||
|
|||||||
@@ -102,6 +102,10 @@ class UnetPatcher(ModelPatcher):
|
|||||||
self.append_model_option('conditioning_modifiers', modifier, ensure_uniqueness)
|
self.append_model_option('conditioning_modifiers', modifier, ensure_uniqueness)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
def add_sampler_pre_cfg_function(self, modifier, ensure_uniqueness=False):
|
||||||
|
self.append_model_option('sampler_pre_cfg_function', modifier, ensure_uniqueness)
|
||||||
|
return
|
||||||
|
|
||||||
def set_memory_peak_estimation_modifier(self, modifier):
|
def set_memory_peak_estimation_modifier(self, modifier):
|
||||||
self.model_options['memory_peak_estimation_modifier'] = modifier
|
self.model_options['memory_peak_estimation_modifier'] = modifier
|
||||||
return
|
return
|
||||||
|
|||||||
Reference in New Issue
Block a user