fix inpaint batch dim align #94

This commit is contained in:
lllyasviel
2024-02-06 22:57:53 -08:00
parent 65f9c7d442
commit 4ea4a92fe9
2 changed files with 2 additions and 2 deletions

View File

@@ -56,7 +56,7 @@ class PreprocessorInpaintOnly(PreprocessorInpaint):
unet = process.sd_model.forge_objects.unet.clone()
def pre_cfg(model, c, uc, x, timestep, model_options):
noisy_latent = latent_image.to(x) + timestep.to(x) * torch.randn_like(latent_image).to(x)
noisy_latent = latent_image.to(x) + timestep[:, None, None, None].to(x) * torch.randn_like(latent_image).to(x)
x = x * latent_mask.to(x) + noisy_latent.to(x) * (1.0 - latent_mask.to(x))
return model, c, uc, x, timestep, model_options

View File

@@ -173,7 +173,7 @@ class CFGDenoiser(torch.nn.Module):
uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
if self.mask is not None:
noisy_initial_latent = self.init_latent + sigma * torch.randn_like(self.init_latent).to(self.init_latent)
noisy_initial_latent = self.init_latent + sigma[:, None, None, None] * torch.randn_like(self.init_latent).to(self.init_latent)
x = x * self.nmask + noisy_initial_latent * self.mask
denoiser_params = CFGDenoiserParams(x, image_cond, sigma, state.sampling_step, state.sampling_steps, cond, uncond, self)