mirror of https://github.com/ostris/ai-toolkit.git
Add targeted flow guidance training for flow based models
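Context for the diff below: flow matching models are trained to predict the velocity target = noise - latents, so a velocity prediction can be inverted back to the noise it implies by adding the latents. This commit uses that inversion to build a loss target that preserves the frozen model's behavior on an unconditional (baseline) image and differs from it only by the conditional/unconditional latent differential. A minimal standalone sketch of the algebra, with illustrative tensor names that are not from the repo:

import torch

# toy tensors; real latents come from the VAE, shapes are illustrative
uncond = torch.randn(1, 4, 8, 8)  # unconditional_latents (baseline image)
cond = torch.randn(1, 4, 8, 8)    # conditional_latents (edited image)

# stand-in for the frozen model's velocity prediction on the baseline image
baseline_prediction = torch.randn(1, 4, 8, 8)

# reverse the flow matching target (velocity = noise - latents) to recover
# the noise the baseline prediction implies
baseline_predicted_noise = baseline_prediction + uncond

# re-target that implied noise against the conditional latents
target_pred = baseline_predicted_noise - cond

# identity: the new target is the frozen model's own prediction shifted by
# exactly the conditional/unconditional latent differential
assert torch.allclose(target_pred, baseline_prediction - (cond - uncond))

Training against target_pred therefore asks the lora to reproduce the frozen prediction plus the latent differential, which is the "learn only our changes" behavior the comments in the diff describe.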
@@ -603,6 +603,94 @@ def get_guided_tnt(
    return loss


def targeted_flow_guidance(
    noisy_latents: torch.Tensor,
    conditional_embeds: 'PromptEmbeds',
    match_adapter_assist: bool,
    network_weight_list: list,
    timesteps: torch.Tensor,
    pred_kwargs: dict,
    batch: 'DataLoaderBatchDTO',
    noise: torch.Tensor,
    sd: 'StableDiffusion',
    unconditional_embeds: Optional[PromptEmbeds] = None,
    mask_multiplier=None,
    prior_pred=None,
    scaler=None,
    train_config=None,
    **kwargs
):
    if not sd.is_flow_matching:
        raise ValueError("targeted_flow only works on flow matching models")
    dtype = get_torch_dtype(sd.torch_dtype)
    device = sd.device_torch
    with torch.no_grad():
        dtype = get_torch_dtype(dtype)
        noise = noise.to(device, dtype=dtype).detach()

        conditional_latents = batch.latents.to(device, dtype=dtype).detach()
        unconditional_latents = batch.unconditional_latents.to(device, dtype=dtype).detach()

        # get noisy latents for both conditional and unconditional predictions
        unconditional_noisy_latents = sd.add_noise(
            unconditional_latents,
            noise,
            timesteps
        ).detach()
        conditional_noisy_latents = sd.add_noise(
            conditional_latents,
            noise,
            timesteps
        ).detach()

        # disable the lora to get a baseline prediction
        sd.network.is_active = False
        sd.unet.eval()

        # get a baseline prediction of the model's knowledge without the lora network.
        # we do this with the unconditional noisy latents
        baseline_prediction = sd.predict_noise(
            latents=unconditional_noisy_latents.to(device, dtype=dtype).detach(),
            conditional_embeddings=conditional_embeds.to(device, dtype=dtype).detach(),
            timestep=timesteps,
            guidance_scale=1.0,
            **pred_kwargs
        ).detach()

        # this is our normal flow matching target:
        #   target = noise - latents
        # we need to target the baseline noise but with our conditional latents.
        # to do this we first determine the baseline prediction's noise by
        # reversing the flow matching target
        baseline_predicted_noise = baseline_prediction + unconditional_latents

        # baseline_predicted_noise is now the noise prediction our model would make
        # for the unconditional image. we use this as our new noise target to
        # preserve the model's existing knowledge of the image
        target_noise = baseline_predicted_noise

        # compute our new target prediction using our current-knowledge noise with
        # our conditional latents. this makes the only new information the
        # differential between our conditional and unconditional latents, forcing
        # the network to preserve existing knowledge but learn only our changes
        target_pred = (target_noise - conditional_latents).detach()

    # make a prediction with the lora network active
    sd.unet.train()
    sd.network.is_active = True
    sd.network.multiplier = network_weight_list
    prediction = sd.predict_noise(
        latents=conditional_noisy_latents.to(device, dtype=dtype).detach(),
        conditional_embeddings=conditional_embeds.to(device, dtype=dtype).detach(),
        timestep=timesteps,
        guidance_scale=1.0,
        **pred_kwargs
    )

    # target our baseline + differential noise target
    pred_loss = torch.nn.functional.mse_loss(
        prediction.float(),
        target_pred.float()
    )

    return pred_loss


# this processes all guidance losses based on the batch information
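A note on the structure above: the entire target construction runs under torch.no_grad() and ends in .detach(), so the MSE target is a constant and gradient flows only through the lora-active forward pass. A self-contained sketch of that property, using toy tensors rather than the repo's types:

import torch

# the target is built under no_grad in the diff, so it carries no graph
target_pred = torch.randn(2, 4)

# stands in for the lora-active prediction, which does require grad
prediction = torch.randn(2, 4, requires_grad=True)

loss = torch.nn.functional.mse_loss(prediction.float(), target_pred.float())
loss.backward()

assert prediction.grad is not None  # the lora path receives gradient
assert target_pred.grad is None     # the baseline path never does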
@@ -702,5 +790,23 @@ def get_guidance_loss(
            prior_pred=prior_pred,
            **kwargs
        )
    elif guidance_type == "targeted_flow":
        return targeted_flow_guidance(
            noisy_latents,
            conditional_embeds,
            match_adapter_assist,
            network_weight_list,
            timesteps,
            pred_kwargs,
            batch,
            noise,
            sd,
            unconditional_embeds=unconditional_embeds,
            mask_multiplier=mask_multiplier,
            prior_pred=prior_pred,
            scaler=scaler,
            train_config=train_config,
            **kwargs
        )
    else:
        raise NotImplementedError(f"Guidance type {guidance_type} is not implemented")
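One thing the dispatch does not check: targeted_flow_guidance reads both batch.latents and batch.unconditional_latents, so the dataset must supply a paired unconditional (baseline) image for every sample. A hypothetical early guard, not part of this commit, could make that failure mode explicit:

# hypothetical guard (not in the commit): fail early when the dataset does
# not provide a paired unconditional image for each sample
def check_targeted_flow_batch(batch) -> None:
    if getattr(batch, 'unconditional_latents', None) is None:
        raise ValueError(
            "targeted_flow requires paired unconditional latents "
            "(batch.unconditional_latents is missing)"
        )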