Update prompt_parser.py

This commit is contained in:
lllyasviel
2024-01-27 14:13:39 -08:00
parent a3b6396037
commit 846bf76ce5

View File

@@ -338,16 +338,18 @@ def stack_conds_alter(tensors, weights):
full_weights = full_weights + float(w)
result = result / full_weights
return result[None].contiguous().clone()
return result
def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
param = c.batch[0][0].schedules[0].cond
tensors = []
weights = []
results = []
for composable_prompts in c.batch:
tensors = []
weights = []
for composable_prompt in composable_prompts:
target_index = 0
for current, entry in enumerate(composable_prompt.schedules):
@@ -358,14 +360,21 @@ def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
weights.append(composable_prompt.weight)
tensors.append(composable_prompt.schedules[target_index].cond)
if isinstance(tensors[0], dict):
keys = list(tensors[0].keys())
weighted = {k: stack_conds_alter([x[k] for x in tensors], weights) for k in keys}
weighted = DictWithShape(weighted, weighted['crossattn'].shape)
else:
weighted = stack_conds_alter(tensors, weights).to(device=param.device, dtype=param.dtype)
if isinstance(tensors[0], dict):
weighted = {k: stack_conds_alter([x[k] for x in tensors], weights) for k in tensors[0].keys()}
else:
weighted = stack_conds_alter(tensors, weights)
return weighted
results.append(weighted)
if isinstance(results[0], dict):
results = {k: torch.stack([x[k] for x in results])
for k in results[0].keys()}
results = DictWithShape(results, results['crossattn'].shape)
else:
results = torch.stack(results).to(device=param.device, dtype=param.dtype)
return results
re_attention = re.compile(r"""