Update sd_samplers_cfg_denoiser.py

This commit is contained in:
lllyasviel
2024-01-25 19:43:06 -08:00
parent c5d26cc807
commit 54f07f1d1d

View File

@@ -65,7 +65,7 @@ class CFGDenoiser(torch.nn.Module):
def inner_model(self):
raise NotImplementedError()
def combine_denoised(self, x_out, conds_list, uncond, cond_scale, timestep, x):
def combine_denoised(self, x_out, conds_list, uncond, cond_scale, timestep, x_in):
model_options = self.inner_model.inner_model.forge_objects.unet.model_options
denoised_uncond = x_out[-uncond.shape[0]:]
@@ -78,10 +78,28 @@ class CFGDenoiser(torch.nn.Module):
if "sampler_cfg_function" in model_options or "sampler_post_cfg_function" in model_options:
cond_scale = float(cond_scale)
model = self.inner_model.inner_model.forge_objects.unet
x = x_in[-uncond.shape[0]:]
uncond_pred = denoised_uncond
cond_pred = ((denoised - uncond_pred) / cond_scale) + uncond_pred
a = 0
if "sampler_cfg_function" in model_options:
args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale,
"timestep": timestep, "input": x, "sigma": timestep,
"cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model,
"model_options": model_options}
cfg_result = x - model_options["sampler_cfg_function"](args)
else:
cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale
return denoised
for fn in model_options.get("sampler_post_cfg_function", []):
args = {"denoised": cfg_result, "cond": torch.zeros_like(uncond), "uncond": torch.zeros_like(uncond), "model": model,
"uncond_denoised": uncond_pred, "cond_denoised": cond_pred,
"sigma": timestep, "model_options": model_options, "input": x}
cfg_result = fn(args)
else:
cfg_result = denoised
return cfg_result
def combine_denoised_for_edit_model(self, x_out, cond_scale):
out_cond, out_img_cond, out_uncond = x_out.chunk(3)