From fa187b1208d6bf777056b2073c164b905499b016 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Mon, 17 Mar 2025 13:25:01 -0600 Subject: [PATCH] Added differential masking to targeted_flow_guidance to allow the model learn to clean up the targeted area a little more than unmasked was capable of --- toolkit/guidance.py | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/toolkit/guidance.py b/toolkit/guidance.py index d2deb278..84242423 100644 --- a/toolkit/guidance.py +++ b/toolkit/guidance.py @@ -24,8 +24,12 @@ def get_differential_mask( ): # make a differential mask differential_mask = torch.abs(conditional_latents - unconditional_latents) - max_differential = \ - differential_mask.max(dim=1, keepdim=True)[0].max(dim=2, keepdim=True)[0].max(dim=3, keepdim=True)[0] + if len(differential_mask.shape) == 4: + max_differential = \ + differential_mask.max(dim=1, keepdim=True)[0].max(dim=2, keepdim=True)[0].max(dim=3, keepdim=True)[0] + elif len(differential_mask.shape) == 5: + max_differential = \ + differential_mask.max(dim=1, keepdim=True)[0].max(dim=2, keepdim=True)[0].max(dim=3, keepdim=True)[0].max(dim=4, keepdim=True)[0] differential_scaler = 1.0 / max_differential differential_mask = differential_mask * differential_scaler @@ -631,6 +635,14 @@ def targeted_flow_guidance( conditional_latents = batch.latents.to(device, dtype=dtype).detach() unconditional_latents = batch.unconditional_latents.to(device, dtype=dtype).detach() + # get a mask on the differential of the latents + # this will be scaled from 0.0-1.0 with 1.0 being the largest differential + abs_differential_mask = get_differential_mask( + conditional_latents, + unconditional_latents, + gradient=True + ) + # get noisy latents for both conditional and unconditional predictions unconditional_noisy_latents = sd.add_noise( unconditional_latents, @@ -664,8 +676,11 @@ def targeted_flow_guidance( baseline_predicted_noise = baseline_prediction + unconditional_latents # baseline_predicted_noise is now the noise prediction our model would make with a the unconditional image. - # we use this as our new noise target to preserve the existing knowledge of the image - target_noise = baseline_predicted_noise + # we use this as our new noise target to preserve the existing knowledge of the image. + # we apply a mask to this noise to only allow the differential of the conditional latents to be learned + baseline_predicted_noise = (1 - abs_differential_mask) * baseline_predicted_noise + masked_noise = abs_differential_mask * noise + target_noise = masked_noise + baseline_predicted_noise # compute our new target prediction using our current knowledge noise with our conditional latents # this makes it so the only new information is the differential of our conditional and unconditional latents