From 8de4896bd6cbb7b45211d274699b73d94a20e509 Mon Sep 17 00:00:00 2001 From: lllyasviel Date: Mon, 5 Feb 2024 13:44:31 -0800 Subject: [PATCH] fix inpaint formulation --- modules/sd_samplers_cfg_denoiser.py | 21 ++++----------------- 1 file changed, 4 insertions(+), 17 deletions(-) diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py index e74d7140..0a243630 100644 --- a/modules/sd_samplers_cfg_denoiser.py +++ b/modules/sd_samplers_cfg_denoiser.py @@ -169,21 +169,9 @@ class CFGDenoiser(torch.nn.Module): cond_composition, cond = prompt_parser.reconstruct_multicond_batch(cond, self.step) uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step) - # If we use masks, blending between the denoised and original latent images occurs here. - def apply_blend(current_latent): - blended_latent = current_latent * self.nmask + self.init_latent * self.mask - - if self.p.scripts is not None: - from modules import scripts - mba = scripts.MaskBlendArgs(current_latent, self.nmask, self.init_latent, self.mask, blended_latent, denoiser=self, sigma=sigma) - self.p.scripts.on_mask_blend(self.p, mba) - blended_latent = mba.blended_latent - - return blended_latent - - # # Blend in the original latents (before, wrong method) - # if self.mask is not None: - # x = apply_blend(x) + if self.mask is not None: + noisy_initial_latent = self.init_latent + sigma * torch.randn_like(self.init_latent).to(self.init_latent) + x = x * self.nmask + noisy_initial_latent * self.mask denoiser_params = CFGDenoiserParams(x, image_cond, sigma, state.sampling_step, state.sampling_steps, cond, uncond, self) cfg_denoiser_callback(denoiser_params) @@ -191,9 +179,8 @@ class CFGDenoiser(torch.nn.Module): denoised = forge_sampler.forge_sample(self, denoiser_params=denoiser_params, cond_scale=cond_scale, cond_composition=cond_composition) - # Blend in the original latents (after, correct method) if self.mask is not None: - denoised = apply_blend(denoised) + denoised = denoised * self.nmask + self.init_latent * self.mask preview = self.sampler.last_latent = denoised sd_samplers_common.store_latent(preview)