From 7b1a1e510fc6bbead3ec3dfa1219c4835fa9bfcc Mon Sep 17 00:00:00 2001
From: lllyasviel
Date: Sat, 27 Jan 2024 13:53:18 -0800
Subject: [PATCH] i

---
 modules/prompt_parser.py            | 33 +++++++++++++++++++++++----------
 modules/sd_samplers_cfg_denoiser.py | 10 ++--------
 2 files changed, 25 insertions(+), 18 deletions(-)

diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py
index 2afd1b7c..0c86c356 100644
--- a/modules/prompt_parser.py
+++ b/modules/prompt_parser.py
@@ -323,16 +323,31 @@ def stack_conds(tensors):
 
     return torch.stack(tensors)
 
+
+def stack_conds_alter(tensors, weights):
+    token_count = max([x.shape[0] for x in tensors])
+    for i in range(len(tensors)):
+        if tensors[i].shape[0] != token_count:
+            last_vector = tensors[i][-1:]
+            last_vector_repeated = last_vector.repeat([token_count - tensors[i].shape[0], 1])
+            tensors[i] = torch.vstack([tensors[i], last_vector_repeated])
+
+    result = 0
+    full_weights = 0
+    for x, w in zip(tensors, weights):
+        result = result + x * float(w)
+        full_weights = full_weights + float(w)
+    result = result / full_weights
+
+    return result
+
 
 def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
     param = c.batch[0][0].schedules[0].cond
 
     tensors = []
-    conds_list = []
+    weights = []
 
     for composable_prompts in c.batch:
-        conds_for_batch = []
-
         for composable_prompt in composable_prompts:
             target_index = 0
             for current, entry in enumerate(composable_prompt.schedules):
@@ -340,19 +355,17 @@ def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
                 target_index = current
                 break
 
-            conds_for_batch.append((len(tensors), composable_prompt.weight))
+            weights.append(composable_prompt.weight)
             tensors.append(composable_prompt.schedules[target_index].cond)
 
-        conds_list.append(conds_for_batch)
-
     if isinstance(tensors[0], dict):
         keys = list(tensors[0].keys())
-        stacked = {k: stack_conds([x[k] for x in tensors]) for k in keys}
-        stacked = DictWithShape(stacked, stacked['crossattn'].shape)
+        weighted = {k: stack_conds_alter([x[k] for x in tensors], weights) for k in keys}
+        weighted = DictWithShape(weighted, weighted['crossattn'].shape)
     else:
-        stacked = stack_conds(tensors).to(device=param.device, dtype=param.dtype)
+        weighted = stack_conds_alter(tensors, weights).to(device=param.device, dtype=param.dtype)
 
-    return conds_list, stacked
+    return weighted
 
 
 re_attention = re.compile(r"""
diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py
index 8ad9fbfe..521d016c 100644
--- a/modules/sd_samplers_cfg_denoiser.py
+++ b/modules/sd_samplers_cfg_denoiser.py
@@ -100,15 +100,9 @@ class CFGDenoiser(torch.nn.Module):
             cond = self.sampler.sampler_extra_args['cond']
             uncond = self.sampler.sampler_extra_args['uncond']
 
-        # at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling,
-        # so is_edit_model is set to False to support AND composition.
-        is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0
-
-        conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
+        cond = prompt_parser.reconstruct_multicond_batch(cond, self.step)
         uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
 
-        assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
-
         # If we use masks, blending between the denoised and original latent images occurs here.
         def apply_blend(current_latent):
             blended_latent = current_latent * self.nmask + self.init_latent * self.mask
@@ -125,7 +119,7 @@ class CFGDenoiser(torch.nn.Module):
         if self.mask_before_denoising and self.mask is not None:
             x = apply_blend(x)
 
-        denoiser_params = CFGDenoiserParams(x, image_cond, sigma, state.sampling_step, state.sampling_steps, tensor, uncond, self)
+        denoiser_params = CFGDenoiserParams(x, image_cond, sigma, state.sampling_step, state.sampling_steps, cond, uncond, self)
         cfg_denoiser_callback(denoiser_params)
 
         denoised = forge_sampler.forge_sample(self, denoiser_params=denoiser_params, cond_scale=cond_scale)
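
A standalone sketch (not part of the patch) of what the new stack_conds_alter does: it pads each prompt's cond to the longest token length by repeating its last vector, then returns a weight-averaged combination, so reconstruct_multicond_batch now yields a single cond tensor instead of the old (conds_list, stacked) pair. The tensor shapes below (77/154 tokens, 768-dim embeddings) are illustrative assumptions, not values taken from the patch.

import torch


def stack_conds_alter(tensors, weights):
    # pad shorter conds to the longest token count by repeating their last vector
    token_count = max([x.shape[0] for x in tensors])
    for i in range(len(tensors)):
        if tensors[i].shape[0] != token_count:
            last_vector = tensors[i][-1:]
            last_vector_repeated = last_vector.repeat([token_count - tensors[i].shape[0], 1])
            tensors[i] = torch.vstack([tensors[i], last_vector_repeated])

    # weighted average of all conds, normalised by the total weight
    result = 0
    full_weights = 0
    for x, w in zip(tensors, weights):
        result = result + x * float(w)
        full_weights = full_weights + float(w)
    return result / full_weights


# illustrative only: two prompt conds from an AND composition, weights 1.0 and 0.5
cond_a = torch.randn(77, 768)
cond_b = torch.randn(154, 768)
merged = stack_conds_alter([cond_a, cond_b], [1.0, 0.5])
print(merged.shape)  # torch.Size([154, 768]) -- one merged cond instead of a per-prompt batch

The old return value was a (conds_list, stacked) pair with per-prompt weights applied downstream; after this patch the weighted merge happens once here, which is why CFGDenoiser now passes the single returned cond straight into CFGDenoiserParams.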