This commit is contained in:
lllyasviel
2024-01-27 18:38:11 -08:00
parent 5d3071331a
commit ab1814bde9
3 changed files with 34 additions and 37 deletions

View File

@@ -323,32 +323,15 @@ def stack_conds(tensors):
return torch.stack(tensors)
def stack_conds_alter(tensors, weights):
    """Return the weighted average of *tensors* after padding them to equal length.

    Shorter tensors are extended **in place** (the input list is mutated) by
    repeating their final token vector until every tensor has as many rows as
    the longest one.  The result is sum(w_i * t_i) / sum(w_i).
    """
    longest = max(t.shape[0] for t in tensors)

    # Pad each short tensor by repeating its last row; assigning back into
    # the list preserves the original's in-place mutation of `tensors`.
    for idx, tensor in enumerate(tensors):
        missing = longest - tensor.shape[0]
        if missing != 0:
            tail = tensor[-1:].repeat([missing, 1])
            tensors[idx] = torch.vstack([tensor, tail])

    accumulated = 0
    weight_total = 0
    for tensor, weight in zip(tensors, weights):
        w = float(weight)
        accumulated = accumulated + tensor * w
        weight_total = weight_total + w

    return accumulated / weight_total
def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
param = c.batch[0][0].schedules[0].cond
results = []
tensors = []
conds_list = []
for composable_prompts in c.batch:
tensors = []
weights = []
conds_for_batch = []
for composable_prompt in composable_prompts:
target_index = 0
@@ -357,24 +340,19 @@ def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
target_index = current
break
weights.append(composable_prompt.weight)
conds_for_batch.append((len(tensors), composable_prompt.weight))
tensors.append(composable_prompt.schedules[target_index].cond)
if isinstance(tensors[0], dict):
weighted = {k: stack_conds_alter([x[k] for x in tensors], weights) for k in tensors[0].keys()}
else:
weighted = stack_conds_alter(tensors, weights)
conds_list.append(conds_for_batch)
results.append(weighted)
if isinstance(results[0], dict):
results = {k: torch.stack([x[k] for x in results])
for k in results[0].keys()}
results = DictWithShape(results, results['crossattn'].shape)
if isinstance(tensors[0], dict):
keys = list(tensors[0].keys())
stacked = {k: stack_conds([x[k] for x in tensors]) for k in keys}
stacked = DictWithShape(stacked, stacked['crossattn'].shape)
else:
results = torch.stack(results).to(device=param.device, dtype=param.dtype)
stacked = stack_conds(tensors).to(device=param.device, dtype=param.dtype)
return results
return conds_list, stacked
re_attention = re.compile(r"""

View File

@@ -157,7 +157,7 @@ class CFGDenoiser(torch.nn.Module):
cond = self.sampler.sampler_extra_args['cond']
uncond = self.sampler.sampler_extra_args['uncond']
cond = prompt_parser.reconstruct_multicond_batch(cond, self.step)
cond_composition, cond = prompt_parser.reconstruct_multicond_batch(cond, self.step)
uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
# If we use masks, blending between the denoised and original latent images occurs here.
@@ -179,7 +179,8 @@ class CFGDenoiser(torch.nn.Module):
denoiser_params = CFGDenoiserParams(x, image_cond, sigma, state.sampling_step, state.sampling_steps, cond, uncond, self)
cfg_denoiser_callback(denoiser_params)
denoised = forge_sampler.forge_sample(self, denoiser_params=denoiser_params, cond_scale=cond_scale)
denoised = forge_sampler.forge_sample(self, denoiser_params=denoiser_params,
cond_scale=cond_scale, cond_composition=cond_composition)
preview = self.sampler.last_latent = denoised
sd_samplers_common.store_latent(preview)