Added differential masking to targeted_flow_guidance to allow the model to learn to clean up the targeted area a little more than the unmasked version was capable of

This commit is contained in:
Jaret Burkett
2025-03-17 13:25:01 -06:00
parent 5eb627dd9d
commit fa187b1208

@@ -24,8 +24,12 @@ def get_differential_mask(
 ):
     # make a differential mask
     differential_mask = torch.abs(conditional_latents - unconditional_latents)
-    max_differential = \
-        differential_mask.max(dim=1, keepdim=True)[0].max(dim=2, keepdim=True)[0].max(dim=3, keepdim=True)[0]
+    if len(differential_mask.shape) == 4:
+        max_differential = \
+            differential_mask.max(dim=1, keepdim=True)[0].max(dim=2, keepdim=True)[0].max(dim=3, keepdim=True)[0]
+    elif len(differential_mask.shape) == 5:
+        max_differential = \
+            differential_mask.max(dim=1, keepdim=True)[0].max(dim=2, keepdim=True)[0].max(dim=3, keepdim=True)[0].max(dim=4, keepdim=True)[0]
     differential_scaler = 1.0 / max_differential
     differential_mask = differential_mask * differential_scaler
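
As a side note, that per-sample normalization can be expressed with a single reduction. The sketch below covers only the scaling step and assumes 4D (B, C, H, W) or 5D (B, C, T, H, W) latents; the real get_differential_mask also takes a gradient flag whose handling is not shown in this hunk.

import torch

def normalize_differential_sketch(conditional_latents, unconditional_latents):
    # absolute per-element difference between conditional and unconditional latents
    differential_mask = torch.abs(conditional_latents - unconditional_latents)
    # take the per-sample maximum over every non-batch dim (channels, frames, H, W)
    reduce_dims = tuple(range(1, differential_mask.ndim))
    max_differential = torch.amax(differential_mask, dim=reduce_dims, keepdim=True)
    # scale so the largest difference in each sample becomes exactly 1.0
    return differential_mask / max_differential

Because torch.amax accepts a tuple of dims, the same reduction handles both the image and video cases that the if/elif above switches between.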
@@ -631,6 +635,14 @@ def targeted_flow_guidance(
     conditional_latents = batch.latents.to(device, dtype=dtype).detach()
     unconditional_latents = batch.unconditional_latents.to(device, dtype=dtype).detach()
+    # get a mask on the differential of the latents
+    # this will be scaled from 0.0-1.0 with 1.0 being the largest differential
+    abs_differential_mask = get_differential_mask(
+        conditional_latents,
+        unconditional_latents,
+        gradient=True
+    )
     # get noisy latents for both conditional and unconditional predictions
     unconditional_noisy_latents = sd.add_noise(
         unconditional_latents,
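
The add_noise call is cut off by the hunk boundary. As an assumption about what it does (the actual sd.add_noise signature and behavior are not shown here), a flow-matching noising step is typically a linear interpolation between the clean latents and the noise at the sampled sigma:

import torch

def add_noise_sketch(latents, noise, sigmas):
    # hypothetical flow-matching noising: sigma = 0 returns the clean latents,
    # sigma = 1 returns pure noise; this is NOT the repo's actual sd.add_noise
    while sigmas.ndim < latents.ndim:
        sigmas = sigmas.unsqueeze(-1)
    return (1.0 - sigmas) * latents + sigmas * noise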
@@ -664,8 +676,11 @@ def targeted_flow_guidance(
     baseline_predicted_noise = baseline_prediction + unconditional_latents
     # baseline_predicted_noise is now the noise prediction our model would make with the unconditional image.
-    # we use this as our new noise target to preserve the existing knowledge of the image
-    target_noise = baseline_predicted_noise
+    # we use this as our new noise target to preserve the existing knowledge of the image.
+    # we apply a mask to this noise so that only the differential of the conditional latents is learned
+    baseline_predicted_noise = (1 - abs_differential_mask) * baseline_predicted_noise
+    masked_noise = abs_differential_mask * noise
+    target_noise = masked_noise + baseline_predicted_noise
     # compute our new target prediction using our current knowledge noise with our conditional latents
     # this makes it so the only new information is the differential of our conditional and unconditional latents
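
Taken together, the new target is a per-element blend: where the conditional and unconditional latents differ most (mask near 1.0) the model is trained toward the real noise, and where they are identical (mask near 0.0) it is trained to reproduce its own baseline prediction, preserving existing knowledge. A minimal sketch of that blend, using the same names as the diff:

import torch

def blend_target_noise(noise, baseline_predicted_noise, abs_differential_mask):
    # mask ~ 1.0 -> learn the real noise (the targeted, differential area)
    # mask ~ 0.0 -> match the baseline prediction (keep prior knowledge intact)
    return (abs_differential_mask * noise
            + (1.0 - abs_differential_mask) * baseline_predicted_noise)

This is equivalent to torch.lerp(baseline_predicted_noise, noise, abs_differential_mask), which reads as "move from the baseline toward the real noise in proportion to how much the targeted area changed."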