From 161a0938d72b258f3f14f4b46cd13bb15799fc64 Mon Sep 17 00:00:00 2001 From: lllyasviel Date: Thu, 25 Jan 2024 19:00:42 -0800 Subject: [PATCH] Update sd_samplers_cfg_denoiser.py --- modules/sd_samplers_cfg_denoiser.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py index ac196175..09140a24 100644 --- a/modules/sd_samplers_cfg_denoiser.py +++ b/modules/sd_samplers_cfg_denoiser.py @@ -185,18 +185,23 @@ class CFGDenoiser(torch.nn.Module): if tensor.shape[1] == uncond.shape[1] or skip_uncond: if is_edit_model: cond_in = catenate_conds([tensor, uncond, uncond]) + cond_or_uncond = [0] * int(tensor.shape[0]) + [1] * int(uncond.shape[0]) + [1] * int(uncond.shape[0]) elif skip_uncond: cond_in = tensor + cond_or_uncond = [0] * int(tensor.shape[0]) else: cond_in = catenate_conds([tensor, uncond]) + cond_or_uncond = [0] * int(tensor.shape[0]) + [1] * int(uncond.shape[0]) if shared.opts.batch_cond_uncond: + self.inner_model.inner_model.cond_or_uncond = cond_or_uncond x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict(cond_in, image_cond_in)) else: x_out = torch.zeros_like(x_in) for batch_offset in range(0, x_out.shape[0], batch_size): a = batch_offset b = a + batch_size + self.inner_model.inner_model.cond_or_uncond = cond_or_uncond[a:b] x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(subscript_cond(cond_in, a, b), image_cond_in[a:b])) else: x_out = torch.zeros_like(x_in) @@ -210,9 +215,11 @@ class CFGDenoiser(torch.nn.Module): else: c_crossattn = torch.cat([tensor[a:b]], uncond) + self.inner_model.inner_model.cond_or_uncond = [0] * int(sigma_in[a:b].shape[0]) x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b])) if not skip_uncond: + self.inner_model.inner_model.cond_or_uncond = [1] * int(sigma_in[-uncond.shape[0]:].shape[0]) x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict(uncond, image_cond_in[-uncond.shape[0]:])) denoised_image_indexes = [x[0][0] for x in conds_list]