From 5eb627dd9d0fdece98e0ecd304a4f70bd0d420d9 Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Mon, 17 Mar 2025 09:21:23 -0600
Subject: [PATCH] Add targeted flow guidance training for flow-based models

---
 toolkit/guidance.py | 106 ++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 106 insertions(+)

diff --git a/toolkit/guidance.py b/toolkit/guidance.py
index b9971dc1..d2deb278 100644
--- a/toolkit/guidance.py
+++ b/toolkit/guidance.py
@@ -603,6 +603,94 @@ def get_guided_tnt(
     return loss
 
 
+def targeted_flow_guidance(
+    noisy_latents: torch.Tensor,
+    conditional_embeds: 'PromptEmbeds',
+    match_adapter_assist: bool,
+    network_weight_list: list,
+    timesteps: torch.Tensor,
+    pred_kwargs: dict,
+    batch: 'DataLoaderBatchDTO',
+    noise: torch.Tensor,
+    sd: 'StableDiffusion',
+    unconditional_embeds: Optional[PromptEmbeds] = None,
+    mask_multiplier=None,
+    prior_pred=None,
+    scaler=None,
+    train_config=None,
+    **kwargs
+):
+    if not sd.is_flow_matching:
+        raise ValueError("targeted_flow only works on flow matching models")
+    dtype = get_torch_dtype(sd.torch_dtype)
+    device = sd.device_torch
+    with torch.no_grad():
+        noise = noise.to(device, dtype=dtype).detach()
+
+        conditional_latents = batch.latents.to(device, dtype=dtype).detach()
+        unconditional_latents = batch.unconditional_latents.to(device, dtype=dtype).detach()
+
+        # get noisy latents for both conditional and unconditional predictions
+        unconditional_noisy_latents = sd.add_noise(
+            unconditional_latents,
+            noise,
+            timesteps
+        ).detach()
+        conditional_noisy_latents = sd.add_noise(
+            conditional_latents,
+            noise,
+            timesteps
+        ).detach()
+
+        # disable the lora to get a baseline prediction
+        sd.network.is_active = False
+        sd.unet.eval()
+
+        # get a baseline prediction of the model's knowledge without the lora network
+        # we do this with the unconditional noisy latents
+        baseline_prediction = sd.predict_noise(
+            latents=unconditional_noisy_latents.to(device, dtype=dtype).detach(),
+            conditional_embeddings=conditional_embeds.to(device, dtype=dtype).detach(),
+            timestep=timesteps,
+            guidance_scale=1.0,
+            **pred_kwargs
+        ).detach()
+
+        # This is our normal flow-matching target:
+        # target = noise - latents
+        # we need to target the baseline noise but with our conditional latents.
+        # to do this we first have to determine the baseline_prediction noise by reversing the flow-matching target
+        baseline_predicted_noise = baseline_prediction + unconditional_latents
+
+        # baseline_predicted_noise is now the noise prediction our model would make with the unconditional image.
+        # we use this as our new noise target to preserve the existing knowledge of the image
+        target_noise = baseline_predicted_noise
+
+        # compute our new target prediction using our current knowledge noise with our conditional latents.
+        # this makes it so the only new information is the differential between our conditional and unconditional latents,
+        # forcing the network to preserve existing knowledge but learn only our changes
+        target_pred = (target_noise - conditional_latents).detach()
+
+    # make a prediction with the lora network active
+    sd.unet.train()
+    sd.network.is_active = True
+    sd.network.multiplier = network_weight_list
+    prediction = sd.predict_noise(
+        latents=conditional_noisy_latents.to(device, dtype=dtype).detach(),
+        conditional_embeddings=conditional_embeds.to(device, dtype=dtype).detach(),
+        timestep=timesteps,
+        guidance_scale=1.0,
+        **pred_kwargs
+    )
+
+    # compute the loss against our baseline + differential noise target
+    pred_loss = torch.nn.functional.mse_loss(
+        prediction.float(),
+        target_pred.float()
+    )
+
+    return pred_loss
 
 
 # this processes all guidance losses based on the batch information
@@ -702,5 +790,23 @@ def get_guidance_loss(
             prior_pred=prior_pred,
             **kwargs
         )
+    elif guidance_type == "targeted_flow":
+        return targeted_flow_guidance(
+            noisy_latents,
+            conditional_embeds,
+            match_adapter_assist,
+            network_weight_list,
+            timesteps,
+            pred_kwargs,
+            batch,
+            noise,
+            sd,
+            unconditional_embeds=unconditional_embeds,
+            mask_multiplier=mask_multiplier,
+            prior_pred=prior_pred,
+            scaler=scaler,
+            train_config=train_config,
+            **kwargs
+        )
     else:
         raise NotImplementedError(f"Guidance type {guidance_type} is not implemented")
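
Note: the target construction above can be summarized with a short, self-contained
sketch. This is a minimal illustration only, using random tensors as stand-ins for
the real latents and baseline prediction; the variable names are illustrative and
not part of the toolkit:

    import torch

    torch.manual_seed(0)
    x_cond = torch.randn(1, 4, 8, 8)      # conditional latents (the target image)
    x_uncond = torch.randn(1, 4, 8, 8)    # unconditional latents (the baseline image)
    v_baseline = torch.randn(1, 4, 8, 8)  # frozen-model prediction, v = noise - x_uncond

    # reverse the flow-matching target to recover the noise the baseline implies
    implied_noise = v_baseline + x_uncond

    # re-apply the flow-matching target, but with the conditional latents
    v_target = implied_noise - x_cond

    # equivalently: the baseline velocity shifted by the latent differential
    assert torch.allclose(v_target, v_baseline + (x_uncond - x_cond))

In other words, the lora is trained toward the frozen model's own velocity plus the
conditional/unconditional latent difference, so the only new information the network
has to learn is the differential between the two images.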