mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Adjustments to guidance
This commit is contained in:
@@ -16,7 +16,7 @@ from toolkit.clip_vision_adapter import ClipVisionAdapter
|
||||
from toolkit.config_modules import GuidanceConfig
|
||||
from toolkit.data_loader import get_dataloader_datasets
|
||||
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO, FileItemDTO
|
||||
from toolkit.guidance import get_targeted_guidance_loss, get_guidance_loss
|
||||
from toolkit.guidance import get_targeted_guidance_loss, get_guidance_loss, GuidanceType
|
||||
from toolkit.image_utils import show_tensors, show_latents
|
||||
from toolkit.ip_adapter import IPAdapter
|
||||
from toolkit.custom_adapter import CustomAdapter
|
||||
@@ -1293,8 +1293,16 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
|
||||
do_correct_pred_norm_prior = self.train_config.correct_pred_norm
|
||||
|
||||
do_guidance_prior = False
|
||||
|
||||
if batch.unconditional_latents is not None:
|
||||
# for this not that, we need a prior pred to normalize
|
||||
guidance_type: GuidanceType = batch.file_items[0].dataset_config.guidance_type
|
||||
if guidance_type == 'tnt':
|
||||
do_guidance_prior = True
|
||||
|
||||
if ((
|
||||
has_adapter_img and self.assistant_adapter and match_adapter_assist) or self.do_prior_prediction or do_reg_prior or do_inverted_masked_prior or self.train_config.correct_pred_norm):
|
||||
has_adapter_img and self.assistant_adapter and match_adapter_assist) or self.do_prior_prediction or do_guidance_prior or do_reg_prior or do_inverted_masked_prior or self.train_config.correct_pred_norm):
|
||||
with self.timer('prior predict'):
|
||||
prior_pred = self.get_prior_prediction(
|
||||
noisy_latents=noisy_latents,
|
||||
|
||||
@@ -482,6 +482,7 @@ def get_guided_loss_polarity(
|
||||
return loss
|
||||
|
||||
|
||||
|
||||
def get_guided_tnt(
|
||||
noisy_latents: torch.Tensor,
|
||||
conditional_embeds: PromptEmbeds,
|
||||
@@ -492,6 +493,7 @@ def get_guided_tnt(
|
||||
batch: 'DataLoaderBatchDTO',
|
||||
noise: torch.Tensor,
|
||||
sd: 'StableDiffusion',
|
||||
prior_pred: torch.Tensor = None,
|
||||
**kwargs
|
||||
):
|
||||
dtype = get_torch_dtype(sd.torch_dtype)
|
||||
@@ -528,6 +530,7 @@ def get_guided_tnt(
|
||||
sd.network.multiplier = cat_network_weight_list
|
||||
sd.network.is_active = True
|
||||
|
||||
|
||||
prediction = sd.predict_noise(
|
||||
latents=cat_latents.to(device, dtype=dtype).detach(),
|
||||
conditional_embeddings=cat_embeds.to(device, dtype=dtype).detach(),
|
||||
@@ -547,11 +550,41 @@ def get_guided_tnt(
|
||||
that_prediction.float(),
|
||||
noise.float(),
|
||||
reduction="none"
|
||||
) * -1.0
|
||||
)
|
||||
|
||||
loss = this_loss + that_loss
|
||||
with torch.no_grad():
|
||||
tnt_loss = this_loss - that_loss
|
||||
|
||||
loss = loss.mean([1, 2, 3])
|
||||
# create a mask by scaling loss from 0 to mean to 1 to 0
|
||||
# this will act to regularize unchanged areas to prior prediction
|
||||
loss_min = tnt_loss.min(dim=1, keepdim=True)[0].min(dim=2, keepdim=True)[0].min(dim=3, keepdim=True)[0]
|
||||
loss_mean = tnt_loss.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)
|
||||
mask = value_map(
|
||||
torch.abs(tnt_loss),
|
||||
loss_min,
|
||||
loss_mean,
|
||||
0.0,
|
||||
1.0
|
||||
).clamp(0.0, 1.0).detach()
|
||||
|
||||
prior_mask = 1.0 - mask
|
||||
|
||||
|
||||
this_loss = this_loss * mask
|
||||
that_loss = that_loss * prior_mask
|
||||
|
||||
this_loss = this_loss.mean([1, 2, 3])
|
||||
that_loss = that_loss.mean([1, 2, 3])
|
||||
|
||||
prior_loss = torch.nn.functional.mse_loss(
|
||||
this_prediction.float(),
|
||||
prior_pred.detach().float(),
|
||||
reduction="none"
|
||||
)
|
||||
prior_loss = prior_loss * prior_mask
|
||||
prior_loss = prior_loss.mean([1, 2, 3])
|
||||
|
||||
loss = prior_loss + this_loss - that_loss
|
||||
|
||||
loss.backward()
|
||||
|
||||
@@ -612,7 +645,7 @@ def get_guidance_loss(
|
||||
)
|
||||
elif guidance_type == "tnt":
|
||||
assert unconditional_embeds is None, "Unconditional embeds are not supported for polarity guidance"
|
||||
return get_guided_loss_polarity(
|
||||
return get_guided_tnt(
|
||||
noisy_latents,
|
||||
conditional_embeds,
|
||||
match_adapter_assist,
|
||||
@@ -622,6 +655,7 @@ def get_guidance_loss(
|
||||
batch,
|
||||
noise,
|
||||
sd,
|
||||
prior_pred=prior_pred,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user