mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Added differential masking to targeted_flow_guidance to allow the model learn to clean up the targeted area a little more than unmasked was capable of
This commit is contained in:
@@ -24,8 +24,12 @@ def get_differential_mask(
|
||||
):
|
||||
# make a differential mask
|
||||
differential_mask = torch.abs(conditional_latents - unconditional_latents)
|
||||
max_differential = \
|
||||
differential_mask.max(dim=1, keepdim=True)[0].max(dim=2, keepdim=True)[0].max(dim=3, keepdim=True)[0]
|
||||
if len(differential_mask.shape) == 4:
|
||||
max_differential = \
|
||||
differential_mask.max(dim=1, keepdim=True)[0].max(dim=2, keepdim=True)[0].max(dim=3, keepdim=True)[0]
|
||||
elif len(differential_mask.shape) == 5:
|
||||
max_differential = \
|
||||
differential_mask.max(dim=1, keepdim=True)[0].max(dim=2, keepdim=True)[0].max(dim=3, keepdim=True)[0].max(dim=4, keepdim=True)[0]
|
||||
differential_scaler = 1.0 / max_differential
|
||||
differential_mask = differential_mask * differential_scaler
|
||||
|
||||
@@ -631,6 +635,14 @@ def targeted_flow_guidance(
|
||||
conditional_latents = batch.latents.to(device, dtype=dtype).detach()
|
||||
unconditional_latents = batch.unconditional_latents.to(device, dtype=dtype).detach()
|
||||
|
||||
# get a mask on the differential of the latents
|
||||
# this will be scaled from 0.0-1.0 with 1.0 being the largest differential
|
||||
abs_differential_mask = get_differential_mask(
|
||||
conditional_latents,
|
||||
unconditional_latents,
|
||||
gradient=True
|
||||
)
|
||||
|
||||
# get noisy latents for both conditional and unconditional predictions
|
||||
unconditional_noisy_latents = sd.add_noise(
|
||||
unconditional_latents,
|
||||
@@ -664,8 +676,11 @@ def targeted_flow_guidance(
|
||||
baseline_predicted_noise = baseline_prediction + unconditional_latents
|
||||
|
||||
# baseline_predicted_noise is now the noise prediction our model would make with a the unconditional image.
|
||||
# we use this as our new noise target to preserve the existing knowledge of the image
|
||||
target_noise = baseline_predicted_noise
|
||||
# we use this as our new noise target to preserve the existing knowledge of the image.
|
||||
# we apply a mask to this noise to only allow the differential of the conditional latents to be learned
|
||||
baseline_predicted_noise = (1 - abs_differential_mask) * baseline_predicted_noise
|
||||
masked_noise = abs_differential_mask * noise
|
||||
target_noise = masked_noise + baseline_predicted_noise
|
||||
|
||||
# compute our new target prediction using our current knowledge noise with our conditional latents
|
||||
# this makes it so the only new information is the differential of our conditional and unconditional latents
|
||||
|
||||
Reference in New Issue
Block a user