Minor fixes

This commit is contained in:
Jaret Burkett
2024-06-23 14:47:40 -06:00
parent 5d47244c57
commit 64f2b085b7
3 changed files with 19 additions and 35 deletions

View File

@@ -339,7 +339,7 @@ class ModelConfig:
self.is_v2: bool = kwargs.get('is_v2', False)
self.is_xl: bool = kwargs.get('is_xl', False)
self.is_pixart: bool = kwargs.get('is_pixart', False)
self.is_pixart_sigma: bool = kwargs.get('is_pixart', False)
self.is_pixart_sigma: bool = kwargs.get('is_pixart_sigma', False)
self.is_v3: bool = kwargs.get('is_v3', False)
if self.is_pixart_sigma:
self.is_pixart = True

View File

@@ -240,7 +240,7 @@ def get_direct_guidance_loss(
noise_pred_uncond, noise_pred_cond = torch.chunk(prediction, 2, dim=0)
guidance_scale = 1.25
guidance_scale = 1.1
guidance_pred = noise_pred_uncond + guidance_scale * (
noise_pred_cond - noise_pred_uncond
)
@@ -552,39 +552,17 @@ def get_guided_tnt(
reduction="none"
)
with torch.no_grad():
tnt_loss = this_loss - that_loss
# create a mask by scaling loss from 0 to mean to 1 to 0
# this will act to regularize unchanged areas to prior prediction
loss_min = tnt_loss.min(dim=1, keepdim=True)[0].min(dim=2, keepdim=True)[0].min(dim=3, keepdim=True)[0]
loss_mean = tnt_loss.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)
mask = value_map(
torch.abs(tnt_loss),
loss_min,
loss_mean,
0.0,
1.0
).clamp(0.0, 1.0).detach()
prior_mask = 1.0 - mask
this_loss = this_loss * mask
that_loss = that_loss * prior_mask
this_loss = this_loss.mean([1, 2, 3])
that_loss = that_loss.mean([1, 2, 3])
# negative loss on that
that_loss = -that_loss.mean([1, 2, 3])
prior_loss = torch.nn.functional.mse_loss(
this_prediction.float(),
prior_pred.detach().float(),
reduction="none"
)
prior_loss = prior_loss * prior_mask
prior_loss = prior_loss.mean([1, 2, 3])
with torch.no_grad():
# match that loss with this loss so it is not a negative value and same scale
that_loss_scaler = torch.abs(this_loss) / torch.abs(that_loss)
loss = prior_loss + this_loss - that_loss
that_loss = that_loss * that_loss_scaler * 0.01
loss = this_loss + that_loss
loss = loss.mean()