Update sd_samplers_cfg_denoiser.py

This commit is contained in:
lllyasviel
2024-01-25 19:00:42 -08:00
parent 3c25b7f892
commit 161a0938d7

View File

@@ -185,18 +185,23 @@ class CFGDenoiser(torch.nn.Module):
if tensor.shape[1] == uncond.shape[1] or skip_uncond:
if is_edit_model:
cond_in = catenate_conds([tensor, uncond, uncond])
cond_or_uncond = [0] * int(tensor.shape[0]) + [1] * int(uncond.shape[0]) + [1] * int(uncond.shape[0])
elif skip_uncond:
cond_in = tensor
cond_or_uncond = [0] * int(tensor.shape[0])
else:
cond_in = catenate_conds([tensor, uncond])
cond_or_uncond = [0] * int(tensor.shape[0]) + [1] * int(uncond.shape[0])
if shared.opts.batch_cond_uncond:
self.inner_model.inner_model.cond_or_uncond = cond_or_uncond
x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict(cond_in, image_cond_in))
else:
x_out = torch.zeros_like(x_in)
for batch_offset in range(0, x_out.shape[0], batch_size):
a = batch_offset
b = a + batch_size
self.inner_model.inner_model.cond_or_uncond = cond_or_uncond[a:b]
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(subscript_cond(cond_in, a, b), image_cond_in[a:b]))
else:
x_out = torch.zeros_like(x_in)
@@ -210,9 +215,11 @@ class CFGDenoiser(torch.nn.Module):
else:
c_crossattn = torch.cat([tensor[a:b]], uncond)
self.inner_model.inner_model.cond_or_uncond = [0] * int(sigma_in[a:b].shape[0])
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b]))
if not skip_uncond:
self.inner_model.inner_model.cond_or_uncond = [1] * int(sigma_in[-uncond.shape[0]:].shape[0])
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict(uncond, image_cond_in[-uncond.shape[0]:]))
denoised_image_indexes = [x[0][0] for x in conds_list]