mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-11 13:39:50 +00:00
Improvements to the VAE trainer. Adjust the denoise prediction of DFE v3 (DiffusionFeatureExtractor3).
This commit is contained in:
@@ -255,30 +255,30 @@ class DiffusionFeatureExtractor3(nn.Module):
|
||||
dtype = torch.bfloat16
|
||||
device = self.vae.device
|
||||
|
||||
# first we step the scheduler from current timestep to the very end for a full denoise
|
||||
# bs = noise_pred.shape[0]
|
||||
# noise_pred_chunks = torch.chunk(noise_pred, bs)
|
||||
# timestep_chunks = torch.chunk(timesteps, bs)
|
||||
# noisy_latent_chunks = torch.chunk(noisy_latents, bs)
|
||||
# stepped_chunks = []
|
||||
# for idx in range(bs):
|
||||
# model_output = noise_pred_chunks[idx]
|
||||
# timestep = timestep_chunks[idx]
|
||||
# scheduler._step_index = None
|
||||
# scheduler._init_step_index(timestep)
|
||||
# sample = noisy_latent_chunks[idx].to(torch.float32)
|
||||
|
||||
# sigma = scheduler.sigmas[scheduler.step_index]
|
||||
# sigma_next = scheduler.sigmas[-1] # use last sigma for final step
|
||||
# prev_sample = sample + (sigma_next - sigma) * model_output
|
||||
# stepped_chunks.append(prev_sample)
|
||||
|
||||
# stepped_latents = torch.cat(stepped_chunks, dim=0)
|
||||
|
||||
if model is not None and hasattr(model, 'get_stepped_pred'):
|
||||
stepped_latents = model.get_stepped_pred(noise_pred, noise)
|
||||
else:
|
||||
stepped_latents = noise - noise_pred
|
||||
# stepped_latents = noise - noise_pred
|
||||
# first we step the scheduler from current timestep to the very end for a full denoise
|
||||
bs = noise_pred.shape[0]
|
||||
noise_pred_chunks = torch.chunk(noise_pred, bs)
|
||||
timestep_chunks = torch.chunk(timesteps, bs)
|
||||
noisy_latent_chunks = torch.chunk(noisy_latents, bs)
|
||||
stepped_chunks = []
|
||||
for idx in range(bs):
|
||||
model_output = noise_pred_chunks[idx]
|
||||
timestep = timestep_chunks[idx]
|
||||
scheduler._step_index = None
|
||||
scheduler._init_step_index(timestep)
|
||||
sample = noisy_latent_chunks[idx].to(torch.float32)
|
||||
|
||||
sigma = scheduler.sigmas[scheduler.step_index]
|
||||
sigma_next = scheduler.sigmas[-1] # use last sigma for final step
|
||||
prev_sample = sample + (sigma_next - sigma) * model_output
|
||||
stepped_chunks.append(prev_sample)
|
||||
|
||||
stepped_latents = torch.cat(stepped_chunks, dim=0)
|
||||
|
||||
latents = stepped_latents.to(self.vae.device, dtype=self.vae.dtype)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user