mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-02-04 23:19:57 +00:00
i
This commit is contained in:
@@ -323,16 +323,31 @@ def stack_conds(tensors):
|
||||
return torch.stack(tensors)
|
||||
|
||||
|
||||
def stack_conds_alter(tensors, weights):
|
||||
token_count = max([x.shape[0] for x in tensors])
|
||||
for i in range(len(tensors)):
|
||||
if tensors[i].shape[0] != token_count:
|
||||
last_vector = tensors[i][-1:]
|
||||
last_vector_repeated = last_vector.repeat([token_count - tensors[i].shape[0], 1])
|
||||
tensors[i] = torch.vstack([tensors[i], last_vector_repeated])
|
||||
|
||||
result = 0
|
||||
full_weights = 0
|
||||
for x, w in zip(tensors, weights):
|
||||
result = result + x * float(w)
|
||||
full_weights = full_weights + float(w)
|
||||
result = result / full_weights
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
|
||||
param = c.batch[0][0].schedules[0].cond
|
||||
|
||||
tensors = []
|
||||
conds_list = []
|
||||
weights = []
|
||||
|
||||
for composable_prompts in c.batch:
|
||||
conds_for_batch = []
|
||||
|
||||
for composable_prompt in composable_prompts:
|
||||
target_index = 0
|
||||
for current, entry in enumerate(composable_prompt.schedules):
|
||||
@@ -340,19 +355,17 @@ def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
|
||||
target_index = current
|
||||
break
|
||||
|
||||
conds_for_batch.append((len(tensors), composable_prompt.weight))
|
||||
weights.append(composable_prompt.weight)
|
||||
tensors.append(composable_prompt.schedules[target_index].cond)
|
||||
|
||||
conds_list.append(conds_for_batch)
|
||||
|
||||
if isinstance(tensors[0], dict):
|
||||
keys = list(tensors[0].keys())
|
||||
stacked = {k: stack_conds([x[k] for x in tensors]) for k in keys}
|
||||
stacked = DictWithShape(stacked, stacked['crossattn'].shape)
|
||||
weighted = {k: stack_conds_alter([x[k] for x in tensors], weights) for k in keys}
|
||||
weighted = DictWithShape(weighted, weighted['crossattn'].shape)
|
||||
else:
|
||||
stacked = stack_conds(tensors).to(device=param.device, dtype=param.dtype)
|
||||
weighted = stack_conds_alter(tensors, weights).to(device=param.device, dtype=param.dtype)
|
||||
|
||||
return conds_list, stacked
|
||||
return weighted
|
||||
|
||||
|
||||
re_attention = re.compile(r"""
|
||||
|
||||
@@ -100,15 +100,9 @@ class CFGDenoiser(torch.nn.Module):
|
||||
cond = self.sampler.sampler_extra_args['cond']
|
||||
uncond = self.sampler.sampler_extra_args['uncond']
|
||||
|
||||
# at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling,
|
||||
# so is_edit_model is set to False to support AND composition.
|
||||
is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0
|
||||
|
||||
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
|
||||
cond = prompt_parser.reconstruct_multicond_batch(cond, self.step)
|
||||
uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
|
||||
|
||||
assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
|
||||
|
||||
# If we use masks, blending between the denoised and original latent images occurs here.
|
||||
def apply_blend(current_latent):
|
||||
blended_latent = current_latent * self.nmask + self.init_latent * self.mask
|
||||
@@ -125,7 +119,7 @@ class CFGDenoiser(torch.nn.Module):
|
||||
if self.mask_before_denoising and self.mask is not None:
|
||||
x = apply_blend(x)
|
||||
|
||||
denoiser_params = CFGDenoiserParams(x, image_cond, sigma, state.sampling_step, state.sampling_steps, tensor, uncond, self)
|
||||
denoiser_params = CFGDenoiserParams(x, image_cond, sigma, state.sampling_step, state.sampling_steps, cond, uncond, self)
|
||||
cfg_denoiser_callback(denoiser_params)
|
||||
|
||||
denoised = forge_sampler.forge_sample(self, denoiser_params=denoiser_params, cond_scale=cond_scale)
|
||||
|
||||
Reference in New Issue
Block a user