Add targeted flow guidance training for flow-based models

Jaret Burkett
2025-03-17 09:21:23 -06:00
parent 604e76d34d
commit 5eb627dd9d


@@ -603,6 +603,94 @@ def get_guided_tnt(
    return loss

def targeted_flow_guidance(
    noisy_latents: torch.Tensor,
    conditional_embeds: 'PromptEmbeds',
    match_adapter_assist: bool,
    network_weight_list: list,
    timesteps: torch.Tensor,
    pred_kwargs: dict,
    batch: 'DataLoaderBatchDTO',
    noise: torch.Tensor,
    sd: 'StableDiffusion',
    unconditional_embeds: Optional['PromptEmbeds'] = None,
    mask_multiplier=None,
    prior_pred=None,
    scaler=None,
    train_config=None,
    **kwargs
):
    if not sd.is_flow_matching:
        raise ValueError("targeted_flow only works on flow matching models")

    dtype = get_torch_dtype(sd.torch_dtype)
    device = sd.device_torch

    with torch.no_grad():
        noise = noise.to(device, dtype=dtype).detach()
        conditional_latents = batch.latents.to(device, dtype=dtype).detach()
        unconditional_latents = batch.unconditional_latents.to(device, dtype=dtype).detach()

        # get noisy latents for both the conditional and unconditional latents
        unconditional_noisy_latents = sd.add_noise(
            unconditional_latents,
            noise,
            timesteps
        ).detach()
        conditional_noisy_latents = sd.add_noise(
            conditional_latents,
            noise,
            timesteps
        ).detach()
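        # note: for flow matching schedulers, sd.add_noise typically interpolates
        # between the clean latents and the noise, roughly
        #   x_t = (1 - t) * latents + t * noise
        # (with t derived from `timesteps`), rather than the DDPM-style
        # sqrt(alpha_bar) * x0 + sqrt(1 - alpha_bar) * noise blend.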

        # disable the lora to get a baseline prediction
        sd.network.is_active = False
        sd.unet.eval()

        # get a baseline prediction of the model's existing knowledge without the lora
        # network, using the unconditional noisy latents
        baseline_prediction = sd.predict_noise(
            latents=unconditional_noisy_latents.to(device, dtype=dtype).detach(),
            conditional_embeddings=conditional_embeds.to(device, dtype=dtype).detach(),
            timestep=timesteps,
            guidance_scale=1.0,
            **pred_kwargs
        ).detach()

        # The normal flow matching target is:
        #   target = noise - latents
        # We want to target the baseline noise, but paired with our conditional latents.
        # To do that, first recover the noise implied by baseline_prediction by reversing
        # the flow matching target:
        baseline_predicted_noise = baseline_prediction + unconditional_latents
        # baseline_predicted_noise is now the noise prediction the model would make for the
        # unconditional image. We use it as our new noise target to preserve the model's
        # existing knowledge of that image.
        target_noise = baseline_predicted_noise

        # Compute the new target prediction using the current-knowledge noise with the
        # conditional latents. The only new information is then the differential between the
        # conditional and unconditional latents, forcing the network to preserve existing
        # knowledge and learn only our changes.
        target_pred = (target_noise - conditional_latents).detach()
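        # In other words (everything here is detached, so this is plain algebra):
        #   baseline_predicted_noise = baseline_prediction + unconditional_latents
        #   target_pred              = baseline_predicted_noise - conditional_latents
        #                            = baseline_prediction + (unconditional_latents - conditional_latents)
        # so the target differs from the frozen model's own prediction only by the
        # differential between the unconditional and conditional latents.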

    # make a prediction with the lora network active
    sd.unet.train()
    sd.network.is_active = True
    sd.network.multiplier = network_weight_list
    prediction = sd.predict_noise(
        latents=conditional_noisy_latents.to(device, dtype=dtype).detach(),
        conditional_embeddings=conditional_embeds.to(device, dtype=dtype).detach(),
        timestep=timesteps,
        guidance_scale=1.0,
        **pred_kwargs
    )

    # regress onto the baseline + differential noise target; the target is detached,
    # so gradients flow only through the lora-active prediction
    pred_loss = torch.nn.functional.mse_loss(
        prediction.float(),
        target_pred.float()
    )

    return pred_loss
# this processes all guidance losses based on the batch information
@@ -702,5 +790,23 @@ def get_guidance_loss(
            prior_pred=prior_pred,
            **kwargs
        )
    elif guidance_type == "targeted_flow":
        return targeted_flow_guidance(
            noisy_latents,
            conditional_embeds,
            match_adapter_assist,
            network_weight_list,
            timesteps,
            pred_kwargs,
            batch,
            noise,
            sd,
            unconditional_embeds=unconditional_embeds,
            mask_multiplier=mask_multiplier,
            prior_pred=prior_pred,
            scaler=scaler,
            train_config=train_config,
            **kwargs
        )
    else:
        raise NotImplementedError(f"Guidance type {guidance_type} is not implemented")