From 846bf76ce5c7d885cbc2019f10ad5aa9c16c4925 Mon Sep 17 00:00:00 2001 From: lllyasviel Date: Sat, 27 Jan 2024 14:13:39 -0800 Subject: [PATCH] Update prompt_parser.py --- modules/prompt_parser.py | 29 +++++++++++++++++++---------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py index a7604e97..b7f0e260 100644 --- a/modules/prompt_parser.py +++ b/modules/prompt_parser.py @@ -338,16 +338,18 @@ def stack_conds_alter(tensors, weights): full_weights = full_weights + float(w) result = result / full_weights - return result[None].contiguous().clone() + return result def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step): param = c.batch[0][0].schedules[0].cond - tensors = [] - weights = [] + results = [] for composable_prompts in c.batch: + tensors = [] + weights = [] + for composable_prompt in composable_prompts: target_index = 0 for current, entry in enumerate(composable_prompt.schedules): @@ -358,14 +360,21 @@ def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step): weights.append(composable_prompt.weight) tensors.append(composable_prompt.schedules[target_index].cond) - if isinstance(tensors[0], dict): - keys = list(tensors[0].keys()) - weighted = {k: stack_conds_alter([x[k] for x in tensors], weights) for k in keys} - weighted = DictWithShape(weighted, weighted['crossattn'].shape) - else: - weighted = stack_conds_alter(tensors, weights).to(device=param.device, dtype=param.dtype) + if isinstance(tensors[0], dict): + weighted = {k: stack_conds_alter([x[k] for x in tensors], weights) for k in tensors[0].keys()} + else: + weighted = stack_conds_alter(tensors, weights) - return weighted + results.append(weighted) + + if isinstance(results[0], dict): + results = {k: torch.stack([x[k] for x in results]) + for k in results[0].keys()} + results = DictWithShape(results, results['crossattn'].shape) + else: + results = torch.stack(results).to(device=param.device, dtype=param.dtype) + + return results re_attention = re.compile(r"""