Adjustments to guidance

This commit is contained in:
Jaret Burkett
2024-04-19 15:00:35 -06:00
parent 2d0a1be59d
commit 377b81ee3e
2 changed files with 48 additions and 6 deletions

View File

@@ -16,7 +16,7 @@ from toolkit.clip_vision_adapter import ClipVisionAdapter
from toolkit.config_modules import GuidanceConfig
from toolkit.data_loader import get_dataloader_datasets
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO, FileItemDTO
from toolkit.guidance import get_targeted_guidance_loss, get_guidance_loss
from toolkit.guidance import get_targeted_guidance_loss, get_guidance_loss, GuidanceType
from toolkit.image_utils import show_tensors, show_latents
from toolkit.ip_adapter import IPAdapter
from toolkit.custom_adapter import CustomAdapter
@@ -1293,8 +1293,16 @@ class SDTrainer(BaseSDTrainProcess):
do_correct_pred_norm_prior = self.train_config.correct_pred_norm
do_guidance_prior = False
if batch.unconditional_latents is not None:
# for this not that, we need a prior pred to normalize
guidance_type: GuidanceType = batch.file_items[0].dataset_config.guidance_type
if guidance_type == 'tnt':
do_guidance_prior = True
if ((
has_adapter_img and self.assistant_adapter and match_adapter_assist) or self.do_prior_prediction or do_reg_prior or do_inverted_masked_prior or self.train_config.correct_pred_norm):
has_adapter_img and self.assistant_adapter and match_adapter_assist) or self.do_prior_prediction or do_guidance_prior or do_reg_prior or do_inverted_masked_prior or self.train_config.correct_pred_norm):
with self.timer('prior predict'):
prior_pred = self.get_prior_prediction(
noisy_latents=noisy_latents,

View File

@@ -482,6 +482,7 @@ def get_guided_loss_polarity(
return loss
def get_guided_tnt(
noisy_latents: torch.Tensor,
conditional_embeds: PromptEmbeds,
@@ -492,6 +493,7 @@ def get_guided_tnt(
batch: 'DataLoaderBatchDTO',
noise: torch.Tensor,
sd: 'StableDiffusion',
prior_pred: torch.Tensor = None,
**kwargs
):
dtype = get_torch_dtype(sd.torch_dtype)
@@ -528,6 +530,7 @@ def get_guided_tnt(
sd.network.multiplier = cat_network_weight_list
sd.network.is_active = True
prediction = sd.predict_noise(
latents=cat_latents.to(device, dtype=dtype).detach(),
conditional_embeddings=cat_embeds.to(device, dtype=dtype).detach(),
@@ -547,11 +550,41 @@ def get_guided_tnt(
that_prediction.float(),
noise.float(),
reduction="none"
) * -1.0
)
loss = this_loss + that_loss
with torch.no_grad():
tnt_loss = this_loss - that_loss
loss = loss.mean([1, 2, 3])
# create a mask by scaling loss from 0 to mean to 1 to 0
# this will act to regularize unchanged areas to prior prediction
loss_min = tnt_loss.min(dim=1, keepdim=True)[0].min(dim=2, keepdim=True)[0].min(dim=3, keepdim=True)[0]
loss_mean = tnt_loss.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)
mask = value_map(
torch.abs(tnt_loss),
loss_min,
loss_mean,
0.0,
1.0
).clamp(0.0, 1.0).detach()
prior_mask = 1.0 - mask
this_loss = this_loss * mask
that_loss = that_loss * prior_mask
this_loss = this_loss.mean([1, 2, 3])
that_loss = that_loss.mean([1, 2, 3])
prior_loss = torch.nn.functional.mse_loss(
this_prediction.float(),
prior_pred.detach().float(),
reduction="none"
)
prior_loss = prior_loss * prior_mask
prior_loss = prior_loss.mean([1, 2, 3])
loss = prior_loss + this_loss - that_loss
loss.backward()
@@ -612,7 +645,7 @@ def get_guidance_loss(
)
elif guidance_type == "tnt":
assert unconditional_embeds is None, "Unconditional embeds are not supported for polarity guidance"
return get_guided_loss_polarity(
return get_guided_tnt(
noisy_latents,
conditional_embeds,
match_adapter_assist,
@@ -622,6 +655,7 @@ def get_guidance_loss(
batch,
noise,
sd,
prior_pred=prior_pred,
**kwargs
)