Improvements to vae trainer. Adjust denoise prediction of DFE v3

Jaret Burkett
2025-05-30 12:06:47 -06:00
parent ffaf2f154a
commit b6d25fcd10
3 changed files with 63 additions and 29 deletions


@@ -89,6 +89,7 @@ class TrainVAEProcess(BaseTrainProcess):
         self.vae_config = self.get_conf('vae_config', None)
         self.dropout = self.get_conf('dropout', 0.0, as_type=float)
         self.train_encoder = self.get_conf('train_encoder', False, as_type=bool)
+        self.random_scaling = self.get_conf('random_scaling', False, as_type=bool)
         if not self.train_encoder:
             # remove losses that only target encoder
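The new flag follows the same get_conf pattern as the surrounding options and defaults to off. A minimal sketch of the defaulting behavior this relies on (get_conf is the toolkit's config helper; its internals are assumed here, not taken from the commit):

    # Assumed behavior of a get_conf-style helper: return the configured
    # value when present, otherwise the default, cast via as_type.
    def get_conf(config: dict, key: str, default, as_type=None):
        value = config.get(key, default)
        return as_type(value) if as_type is not None else value

    print(get_conf({}, 'random_scaling', False, as_type=bool))                      # False
    print(get_conf({'random_scaling': 1}, 'random_scaling', False, as_type=bool))   # True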
@@ -159,7 +160,11 @@ class TrainVAEProcess(BaseTrainProcess):
         for dataset in self.datasets_objects:
             print(f" - Dataset: {dataset['path']}")
             ds = copy.copy(dataset)
-            ds['resolution'] = self.resolution
+            dataset_res = self.resolution
+            if self.random_scaling:
+                # scale 2x to allow for random scaling
+                dataset_res = int(dataset_res * 2)
+            ds['resolution'] = dataset_res
             image_dataset = ImageDataset(ds)
             datasets.append(image_dataset)
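With random_scaling enabled, images are loaded at twice the configured resolution, giving the training loop headroom to draw a smaller encode size while still ending at the configured training resolution. A rough sketch of the sizes involved (numbers illustrative, not from the commit):

    # Illustrative sizes: target 512 training resolution, dataset loaded at 2x.
    resolution = 512
    dataset_res = resolution * 2            # 1024 when random_scaling is on

    # the training loop below picks one of two scale factors before encoding
    for scale_factor in (0.25, 0.5):
        encode_px = int(dataset_res * scale_factor)
        print(scale_factor, encode_px)      # 0.25 -> 256, 0.5 -> 512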
@@ -168,7 +173,7 @@ class TrainVAEProcess(BaseTrainProcess):
             concatenated_dataset,
             batch_size=self.batch_size,
             shuffle=True,
-            num_workers=8
+            num_workers=16
         )

     def remove_oldest_checkpoint(self):
@@ -573,6 +578,9 @@ class TrainVAEProcess(BaseTrainProcess):
         epoch_losses = copy.deepcopy(blank_losses)
         log_losses = copy.deepcopy(blank_losses)
         # range start at self.epoch_num go to self.epochs
+        latent_size = self.resolution // self.vae_scale_factor
         for epoch in range(self.epoch_num, self.epochs, 1):
             if self.step_num >= self.max_steps:
                 break
@@ -580,8 +588,20 @@ class TrainVAEProcess(BaseTrainProcess):
                 if self.step_num >= self.max_steps:
                     break
                 with torch.no_grad():
                     batch = batch.to(self.device, dtype=self.torch_dtype)
+                    if self.random_scaling:
+                        # only random scale 0.5 of the time
+                        if random.random() < 0.5:
+                            # random scale the batch
+                            scale_factor = 0.25
+                        else:
+                            scale_factor = 0.5
+                        new_size = (int(batch.shape[2] * scale_factor), int(batch.shape[3] * scale_factor))
+                        # make sure it is vae divisible
+                        new_size = (new_size[0] // self.vae_scale_factor * self.vae_scale_factor,
+                                    new_size[1] // self.vae_scale_factor * self.vae_scale_factor)
+                        batch = torch.nn.functional.interpolate(batch, size=new_size, mode='bilinear', align_corners=False)
                     # resize so it matches size of vae evenly
                     if batch.shape[2] % self.vae_scale_factor != 0 or batch.shape[3] % self.vae_scale_factor != 0:
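The floor-divide-then-multiply rounding above is the standard way to snap a size down to a multiple of the VAE's spatial downscale factor, so the encoder always sees evenly divisible inputs. A standalone sketch (a vae_scale_factor of 8 is assumed, as in SD-style VAEs):

    def snap_down(x: int, multiple: int) -> int:
        # round down to the nearest multiple, e.g. 518 -> 512 for multiple 8
        return x // multiple * multiple

    vae_scale_factor = 8                       # assumed 8x spatial downscale
    print(snap_down(518, vae_scale_factor))    # 512
    print(snap_down(256, vae_scale_factor))    # 256, already divisible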
@@ -615,6 +635,11 @@ class TrainVAEProcess(BaseTrainProcess):
                         if do_flip_y > 0:
                             latent_chunks[i] = torch.flip(latent_chunks[i], [3])
                             batch_chunks[i] = torch.flip(batch_chunks[i], [3])
+                        # resize latent to fit
+                        if latent_chunks[i].shape[2] != latent_size or latent_chunks[i].shape[3] != latent_size:
+                            latent_chunks[i] = torch.nn.functional.interpolate(latent_chunks[i], size=(latent_size, latent_size), mode='bilinear', align_corners=False)
                         # if do_scale > 0:
                         #     scale = 2
                         #     start_latent_h = latent_chunks[i].shape[2]
@@ -643,6 +668,10 @@ class TrainVAEProcess(BaseTrainProcess):
                     forward_latents = channel_dropout(latents, self.dropout)
                 else:
                     forward_latents = latents
+                # resize batch to resolution if needed
+                if batch_chunks[0].shape[2] != self.resolution or batch_chunks[0].shape[3] != self.resolution:
+                    batch_chunks = [torch.nn.functional.interpolate(b, size=(self.resolution, self.resolution), mode='bilinear', align_corners=False) for b in batch_chunks]
+                    batch = torch.cat(batch_chunks, dim=0)
             else:
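Taken together, the random-scaling pieces keep the loss targets at the configured resolution no matter which scale was drawn. A worked example of the shapes for one possible draw (all values assumed):

    # Assumed: resolution 512, vae_scale_factor 8, random_scaling on.
    resolution = 512
    vae_scale_factor = 8
    latent_size = resolution // vae_scale_factor   # 64, computed before the epoch loop

    dataset_res = resolution * 2                   # images load at 1024
    scale_factor = 0.25                            # one of the two draws
    encode_px = int(dataset_res * scale_factor)    # batch encoded at 256
    latent_px = encode_px // vae_scale_factor      # 32x32 latents

    # the latents are then interpolated 32 -> 64 and the pixel batch 256 -> 512,
    # so decoder outputs and targets always match the configured resolution.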


@@ -220,10 +220,15 @@ class Critic:
         return float(np.mean(critic_losses))

     def get_lr(self):
-        if self.optimizer_type.startswith('dadaptation'):
-            return (
-                self.optimizer.param_groups[0]["d"]
-                * self.optimizer.param_groups[0]["lr"]
-            )
-        return self.optimizer.param_groups[0]["lr"]
+        if hasattr(self.optimizer, 'get_avg_learning_rate'):
+            learning_rate = self.optimizer.get_avg_learning_rate()
+        elif self.optimizer_type.startswith('dadaptation') or \
+                self.optimizer_type.lower().startswith('prodigy'):
+            learning_rate = (
+                self.optimizer.param_groups[0]["d"] *
+                self.optimizer.param_groups[0]["lr"]
+            )
+        else:
+            learning_rate = self.optimizer.param_groups[0]['lr']
+        return learning_rate
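For context: D-Adaptation and Prodigy are typically run with the configured lr near 1.0 and learn a separate scale d, so the effective step size is their product, which is what the dadaptation/prodigy branch reports. A minimal sketch (param-group values assumed):

    # Minimal sketch of the effective LR for a Prodigy/D-Adaptation-style
    # param group that carries a learned scale "d" (values assumed).
    param_group = {"lr": 1.0, "d": 3.2e-5}
    effective_lr = param_group["d"] * param_group["lr"]
    print(effective_lr)   # 3.2e-05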


@@ -255,30 +255,30 @@ class DiffusionFeatureExtractor3(nn.Module):
         dtype = torch.bfloat16
         device = self.vae.device
-        # stepped_latents = noise - noise_pred
-        # first we step the scheduler from current timestep to the very end for a full denoise
-        bs = noise_pred.shape[0]
-        noise_pred_chunks = torch.chunk(noise_pred, bs)
-        timestep_chunks = torch.chunk(timesteps, bs)
-        noisy_latent_chunks = torch.chunk(noisy_latents, bs)
-        stepped_chunks = []
-        for idx in range(bs):
-            model_output = noise_pred_chunks[idx]
-            timestep = timestep_chunks[idx]
-            scheduler._step_index = None
-            scheduler._init_step_index(timestep)
-            sample = noisy_latent_chunks[idx].to(torch.float32)
-            sigma = scheduler.sigmas[scheduler.step_index]
-            sigma_next = scheduler.sigmas[-1]  # use last sigma for final step
-            prev_sample = sample + (sigma_next - sigma) * model_output
-            stepped_chunks.append(prev_sample)
-        stepped_latents = torch.cat(stepped_chunks, dim=0)
+        # first we step the scheduler from current timestep to the very end for a full denoise
+        # bs = noise_pred.shape[0]
+        # noise_pred_chunks = torch.chunk(noise_pred, bs)
+        # timestep_chunks = torch.chunk(timesteps, bs)
+        # noisy_latent_chunks = torch.chunk(noisy_latents, bs)
+        # stepped_chunks = []
+        # for idx in range(bs):
+        #     model_output = noise_pred_chunks[idx]
+        #     timestep = timestep_chunks[idx]
+        #     scheduler._step_index = None
+        #     scheduler._init_step_index(timestep)
+        #     sample = noisy_latent_chunks[idx].to(torch.float32)
+        #     sigma = scheduler.sigmas[scheduler.step_index]
+        #     sigma_next = scheduler.sigmas[-1]  # use last sigma for final step
+        #     prev_sample = sample + (sigma_next - sigma) * model_output
+        #     stepped_chunks.append(prev_sample)
+        # stepped_latents = torch.cat(stepped_chunks, dim=0)
+        if model is not None and hasattr(model, 'get_stepped_pred'):
+            stepped_latents = model.get_stepped_pred(noise_pred, noise)
+        else:
+            stepped_latents = noise - noise_pred
         latents = stepped_latents.to(self.vae.device, dtype=self.vae.dtype)
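The removed loop performed one Euler step of the flow-matching ODE straight to the scheduler's final sigma: prev = x_t + (sigma_last - sigma_t) * v, which with sigma_last = 0 collapses to x0 = x_t - sigma_t * v. The new fallback noise - noise_pred is the sigma_t = 1 special case of that, assuming a velocity-style prediction v ~= noise - x0 and a pure-noise starting sample. A sketch of the equivalence (conventions and values assumed, not taken from the model):

    import torch

    # Assumed flow-matching convention: x_t = (1 - sigma) * x0 + sigma * noise,
    # model predicts v ~= noise - x0; one Euler step to sigma_next = 0 recovers x0.
    def step_to_end(x_t: torch.Tensor, v: torch.Tensor, sigma: float) -> torch.Tensor:
        sigma_next = 0.0   # the scheduler's last sigma in the removed loop
        return x_t + (sigma_next - sigma) * v

    x0 = torch.randn(1, 4, 8, 8)
    noise = torch.randn_like(x0)
    sigma = 1.0                        # pure-noise endpoint
    x_t = (1 - sigma) * x0 + sigma * noise
    v = noise - x0                     # ideal velocity prediction
    print(torch.allclose(step_to_end(x_t, v, sigma), x0))   # True: matches noise - noise_pred at sigma = 1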