mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-25 08:49:14 +00:00
Small fixed for DFE, polar guidance, and other things
This commit is contained in:
@@ -226,45 +226,48 @@ class DiffusionFeatureExtractor3(nn.Module):
|
||||
return feats_list
|
||||
|
||||
# do lpips
|
||||
lpips_feat_list = [x.detach() for x in get_lpips_features(
|
||||
lpips_feat_list = [x for x in get_lpips_features(
|
||||
tensors_n1p1.to(device, dtype=torch.float32))]
|
||||
|
||||
return lpips_feat_list
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
self,
|
||||
noise,
|
||||
noise_pred,
|
||||
noisy_latents,
|
||||
timesteps,
|
||||
batch: DataLoaderBatchDTO,
|
||||
scheduler: CustomFlowMatchEulerDiscreteScheduler,
|
||||
lpips_weight=20.0,
|
||||
lpips_weight=1.0,
|
||||
clip_weight=0.1,
|
||||
pixel_weight=1.0
|
||||
pixel_weight=0.1
|
||||
):
|
||||
dtype = torch.bfloat16
|
||||
device = self.vae.device
|
||||
|
||||
# first we step the scheduler from current timestep to the very end for a full denoise
|
||||
bs = noise_pred.shape[0]
|
||||
noise_pred_chunks = torch.chunk(noise_pred, bs)
|
||||
timestep_chunks = torch.chunk(timesteps, bs)
|
||||
noisy_latent_chunks = torch.chunk(noisy_latents, bs)
|
||||
stepped_chunks = []
|
||||
for idx in range(bs):
|
||||
model_output = noise_pred_chunks[idx]
|
||||
timestep = timestep_chunks[idx]
|
||||
scheduler._step_index = None
|
||||
scheduler._init_step_index(timestep)
|
||||
sample = noisy_latent_chunks[idx].to(torch.float32)
|
||||
# bs = noise_pred.shape[0]
|
||||
# noise_pred_chunks = torch.chunk(noise_pred, bs)
|
||||
# timestep_chunks = torch.chunk(timesteps, bs)
|
||||
# noisy_latent_chunks = torch.chunk(noisy_latents, bs)
|
||||
# stepped_chunks = []
|
||||
# for idx in range(bs):
|
||||
# model_output = noise_pred_chunks[idx]
|
||||
# timestep = timestep_chunks[idx]
|
||||
# scheduler._step_index = None
|
||||
# scheduler._init_step_index(timestep)
|
||||
# sample = noisy_latent_chunks[idx].to(torch.float32)
|
||||
|
||||
sigma = scheduler.sigmas[scheduler.step_index]
|
||||
sigma_next = scheduler.sigmas[-1] # use last sigma for final step
|
||||
prev_sample = sample + (sigma_next - sigma) * model_output
|
||||
stepped_chunks.append(prev_sample)
|
||||
# sigma = scheduler.sigmas[scheduler.step_index]
|
||||
# sigma_next = scheduler.sigmas[-1] # use last sigma for final step
|
||||
# prev_sample = sample + (sigma_next - sigma) * model_output
|
||||
# stepped_chunks.append(prev_sample)
|
||||
|
||||
stepped_latents = torch.cat(stepped_chunks, dim=0)
|
||||
# stepped_latents = torch.cat(stepped_chunks, dim=0)
|
||||
|
||||
stepped_latents = noise - noise_pred
|
||||
|
||||
latents = stepped_latents.to(self.vae.device, dtype=self.vae.dtype)
|
||||
|
||||
@@ -274,16 +277,18 @@ class DiffusionFeatureExtractor3(nn.Module):
|
||||
|
||||
pred_images = (tensors_n1p1 + 1) / 2 # 0 to 1
|
||||
|
||||
pred_clip_output = self.get_siglip_features(pred_images)
|
||||
lpips_feat_list_pred = self.get_lpips_features(pred_images.float())
|
||||
|
||||
total_loss = 0
|
||||
|
||||
with torch.no_grad():
|
||||
target_img = batch.tensor.to(device, dtype=dtype)
|
||||
# go from -1 to 1 to 0 to 1
|
||||
target_img = (target_img + 1) / 2
|
||||
target_clip_output = self.get_siglip_features(target_img).detach()
|
||||
lpips_feat_list_target = self.get_lpips_features(target_img.float())
|
||||
|
||||
target_clip_output = self.get_siglip_features(target_img).detach()
|
||||
|
||||
pred_clip_output = self.get_siglip_features(pred_images)
|
||||
clip_loss = torch.nn.functional.mse_loss(
|
||||
pred_clip_output.float(), target_clip_output.float()
|
||||
) * clip_weight
|
||||
@@ -293,7 +298,7 @@ class DiffusionFeatureExtractor3(nn.Module):
|
||||
else:
|
||||
self.losses['clip_loss'] += clip_loss.item()
|
||||
|
||||
total_loss = clip_loss
|
||||
total_loss += clip_loss
|
||||
|
||||
lpips_loss = 0
|
||||
for idx, lpips_feat in enumerate(lpips_feat_list_pred):
|
||||
@@ -308,14 +313,14 @@ class DiffusionFeatureExtractor3(nn.Module):
|
||||
|
||||
total_loss += lpips_loss
|
||||
|
||||
mse_loss = torch.nn.functional.mse_loss(
|
||||
stepped_latents.float(), batch.latents.float()
|
||||
) * pixel_weight
|
||||
# mse_loss = torch.nn.functional.mse_loss(
|
||||
# stepped_latents.float(), batch.latents.float()
|
||||
# ) * pixel_weight
|
||||
|
||||
if 'pixel_loss' not in self.losses:
|
||||
self.losses['pixel_loss'] = mse_loss.item()
|
||||
else:
|
||||
self.losses['pixel_loss'] += mse_loss.item()
|
||||
# if 'pixel_loss' not in self.losses:
|
||||
# self.losses['pixel_loss'] = mse_loss.item()
|
||||
# else:
|
||||
# self.losses['pixel_loss'] += mse_loss.item()
|
||||
|
||||
if self.step % self.log_every == 0 and self.step > 0:
|
||||
print(f"DFE losses:")
|
||||
@@ -325,7 +330,7 @@ class DiffusionFeatureExtractor3(nn.Module):
|
||||
print(f" - {key}: {self.losses[key]:.3e}")
|
||||
self.losses[key] = 0.0
|
||||
|
||||
total_loss += mse_loss
|
||||
# total_loss += mse_loss
|
||||
self.step += 1
|
||||
|
||||
return total_loss
|
||||
|
||||
Reference in New Issue
Block a user