Update sd_samplers_cfg_denoiser.py

This commit is contained in:
lllyasviel
2024-01-25 19:52:03 -08:00
parent d3cb546cc5
commit 01c422678c

View File

@@ -65,7 +65,7 @@ class CFGDenoiser(torch.nn.Module):
def inner_model(self):
raise NotImplementedError()
def combine_denoised(self, x_out, conds_list, uncond, cond_scale, timestep, x_in):
def combine_denoised(self, x_out, conds_list, uncond, cond_scale, timestep, x_in, cond):
model_options = self.inner_model.inner_model.forge_objects.unet.model_options
denoised_uncond = x_out[-uncond.shape[0]:]
@@ -93,8 +93,8 @@ class CFGDenoiser(torch.nn.Module):
# sanity_check = torch.allclose(cfg_result, denoised)
for fn in model_options.get("sampler_post_cfg_function", []):
args = {"denoised": cfg_result, "cond": torch.zeros_like(uncond),
"uncond": torch.zeros_like(uncond), "model": model,
args = {"denoised": cfg_result, "cond": cond,
"uncond": uncond, "model": model,
"uncond_denoised": uncond_pred, "cond_denoised": cond_pred,
"sigma": timestep, "model_options": model_options, "input": x}
cfg_result = fn(args)
@@ -265,9 +265,9 @@ class CFGDenoiser(torch.nn.Module):
if is_edit_model:
denoised = self.combine_denoised_for_edit_model(x_out, cond_scale)
elif skip_uncond:
denoised = self.combine_denoised(x_out, conds_list, uncond, 1.0, sigma_in, x_in)
denoised = self.combine_denoised(x_out, conds_list, uncond, 1.0, sigma_in, x_in, tensor)
else:
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale, sigma_in, x_in)
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale, sigma_in, x_in, tensor)
# Blend in the original latents (after)
if not self.mask_before_denoising and self.mask is not None: