Improvements to the VAE trainer. Adjust the denoise prediction of DFE v3.

This commit is contained in:
Jaret Burkett
2025-05-30 12:06:47 -06:00
parent ffaf2f154a
commit b6d25fcd10
3 changed files with 63 additions and 29 deletions

View File

@@ -255,30 +255,30 @@ class DiffusionFeatureExtractor3(nn.Module):
dtype = torch.bfloat16
device = self.vae.device
# first we step the scheduler from current timestep to the very end for a full denoise
# bs = noise_pred.shape[0]
# noise_pred_chunks = torch.chunk(noise_pred, bs)
# timestep_chunks = torch.chunk(timesteps, bs)
# noisy_latent_chunks = torch.chunk(noisy_latents, bs)
# stepped_chunks = []
# for idx in range(bs):
# model_output = noise_pred_chunks[idx]
# timestep = timestep_chunks[idx]
# scheduler._step_index = None
# scheduler._init_step_index(timestep)
# sample = noisy_latent_chunks[idx].to(torch.float32)
# sigma = scheduler.sigmas[scheduler.step_index]
# sigma_next = scheduler.sigmas[-1] # use last sigma for final step
# prev_sample = sample + (sigma_next - sigma) * model_output
# stepped_chunks.append(prev_sample)
# stepped_latents = torch.cat(stepped_chunks, dim=0)
if model is not None and hasattr(model, 'get_stepped_pred'):
stepped_latents = model.get_stepped_pred(noise_pred, noise)
else:
stepped_latents = noise - noise_pred
# stepped_latents = noise - noise_pred
# first we step the scheduler from current timestep to the very end for a full denoise
bs = noise_pred.shape[0]
noise_pred_chunks = torch.chunk(noise_pred, bs)
timestep_chunks = torch.chunk(timesteps, bs)
noisy_latent_chunks = torch.chunk(noisy_latents, bs)
stepped_chunks = []
for idx in range(bs):
model_output = noise_pred_chunks[idx]
timestep = timestep_chunks[idx]
scheduler._step_index = None
scheduler._init_step_index(timestep)
sample = noisy_latent_chunks[idx].to(torch.float32)
sigma = scheduler.sigmas[scheduler.step_index]
sigma_next = scheduler.sigmas[-1] # use last sigma for final step
prev_sample = sample + (sigma_next - sigma) * model_output
stepped_chunks.append(prev_sample)
stepped_latents = torch.cat(stepped_chunks, dim=0)
latents = stepped_latents.to(self.vae.device, dtype=self.vae.dtype)