diff --git a/toolkit/guidance.py b/toolkit/guidance.py index d2deb278..84242423 100644 --- a/toolkit/guidance.py +++ b/toolkit/guidance.py @@ -24,8 +24,12 @@ def get_differential_mask( ): # make a differential mask differential_mask = torch.abs(conditional_latents - unconditional_latents) - max_differential = \ - differential_mask.max(dim=1, keepdim=True)[0].max(dim=2, keepdim=True)[0].max(dim=3, keepdim=True)[0] + if len(differential_mask.shape) == 4: + max_differential = \ + differential_mask.max(dim=1, keepdim=True)[0].max(dim=2, keepdim=True)[0].max(dim=3, keepdim=True)[0] + elif len(differential_mask.shape) == 5: + max_differential = \ + differential_mask.max(dim=1, keepdim=True)[0].max(dim=2, keepdim=True)[0].max(dim=3, keepdim=True)[0].max(dim=4, keepdim=True)[0] differential_scaler = 1.0 / max_differential differential_mask = differential_mask * differential_scaler @@ -631,6 +635,14 @@ def targeted_flow_guidance( conditional_latents = batch.latents.to(device, dtype=dtype).detach() unconditional_latents = batch.unconditional_latents.to(device, dtype=dtype).detach() + # get a mask on the differential of the latents + # this will be scaled from 0.0-1.0 with 1.0 being the largest differential + abs_differential_mask = get_differential_mask( + conditional_latents, + unconditional_latents, + gradient=True + ) + # get noisy latents for both conditional and unconditional predictions unconditional_noisy_latents = sd.add_noise( unconditional_latents, @@ -664,8 +676,11 @@ def targeted_flow_guidance( baseline_predicted_noise = baseline_prediction + unconditional_latents # baseline_predicted_noise is now the noise prediction our model would make with a the unconditional image. - # we use this as our new noise target to preserve the existing knowledge of the image - target_noise = baseline_predicted_noise + # we use this as our new noise target to preserve the existing knowledge of the image. + # we apply a mask to this noise to only allow the differential of the conditional latents to be learned + baseline_predicted_noise = (1 - abs_differential_mask) * baseline_predicted_noise + masked_noise = abs_differential_mask * noise + target_noise = masked_noise + baseline_predicted_noise # compute our new target prediction using our current knowledge noise with our conditional latents # this makes it so the only new information is the differential of our conditional and unconditional latents