Mirror of https://github.com/ostris/ai-toolkit.git
Improvements to vae trainer. Adjust denoise prediction of DFE v3
@@ -89,6 +89,7 @@ class TrainVAEProcess(BaseTrainProcess):
         self.vae_config = self.get_conf('vae_config', None)
         self.dropout = self.get_conf('dropout', 0.0, as_type=float)
         self.train_encoder = self.get_conf('train_encoder', False, as_type=bool)
+        self.random_scaling = self.get_conf('random_scaling', False, as_type=bool)

         if not self.train_encoder:
             # remove losses that only target encoder
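
Note: a minimal sketch of how these options might be set in a trainer config, written as a plain Python dict. Only the key names ('vae_config', 'dropout', 'train_encoder', 'random_scaling') and their defaults come from the get_conf() calls above; the variable name and surrounding structure are assumptions for illustration only.

    # Hypothetical config fragment; keys mirror the get_conf() lookups above.
    vae_train_config = {
        'vae_config': None,       # optional VAE architecture config
        'dropout': 0.0,           # latent channel dropout probability
        'train_encoder': False,   # False: encoder stays frozen, only decoder trains
        'random_scaling': True,   # new flag added in this commit
    }
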
@@ -159,7 +160,11 @@ class TrainVAEProcess(BaseTrainProcess):
         for dataset in self.datasets_objects:
             print(f" - Dataset: {dataset['path']}")
             ds = copy.copy(dataset)
-            ds['resolution'] = self.resolution
+            dataset_res = self.resolution
+            if self.random_scaling:
+                # scale 2x to allow for random scaling
+                dataset_res = int(dataset_res * 2)
+            ds['resolution'] = dataset_res
             image_dataset = ImageDataset(ds)
             datasets.append(image_dataset)

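
Note: with random_scaling enabled, the dataset is loaded at twice the training resolution so that the 0.5 / 0.25 scale factors applied later in the loop land back on 1x / 0.5x of the training resolution. A quick arithmetic check; the resolution value here is just an assumed example:

    resolution = 256               # assumed training resolution
    dataset_res = resolution * 2   # what the dataset is resized to when random_scaling is on
    print(int(dataset_res * 0.5))  # 256 -> full training resolution
    print(int(dataset_res * 0.25)) # 128 -> half the training resolution
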
@@ -168,7 +173,7 @@ class TrainVAEProcess(BaseTrainProcess):
             concatenated_dataset,
             batch_size=self.batch_size,
             shuffle=True,
-            num_workers=8
+            num_workers=16
         )

     def remove_oldest_checkpoint(self):
@@ -573,6 +578,9 @@ class TrainVAEProcess(BaseTrainProcess):
         epoch_losses = copy.deepcopy(blank_losses)
         log_losses = copy.deepcopy(blank_losses)
         # range start at self.epoch_num go to self.epochs
+
+        latent_size = self.resolution // self.vae_scale_factor
+
         for epoch in range(self.epoch_num, self.epochs, 1):
             if self.step_num >= self.max_steps:
                 break
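
Note: latent_size is the spatial size the VAE latents should have at the training resolution. For example, with a 512px training resolution and a downscale factor of 8 (both assumed here; 8 is typical for SD-style VAEs), the target latent size works out as:

    resolution = 512        # assumed training resolution
    vae_scale_factor = 8    # assumed; typical for SD-style VAEs
    latent_size = resolution // vae_scale_factor
    print(latent_size)      # 64
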
@@ -580,8 +588,20 @@ class TrainVAEProcess(BaseTrainProcess):
                 if self.step_num >= self.max_steps:
                     break
                 with torch.no_grad():

                     batch = batch.to(self.device, dtype=self.torch_dtype)

+                    if self.random_scaling:
+                        # only random scale 0.5 of the time
+                        if random.random() < 0.5:
+                            # random scale the batch
+                            scale_factor = 0.25
+                        else:
+                            scale_factor = 0.5
+                        new_size = (int(batch.shape[2] * scale_factor), int(batch.shape[3] * scale_factor))
+                        # make sure it is vae divisible
+                        new_size = (new_size[0] // self.vae_scale_factor * self.vae_scale_factor,
+                                    new_size[1] // self.vae_scale_factor * self.vae_scale_factor)
+
                     # resize so it matches size of vae evenly
                     if batch.shape[2] % self.vae_scale_factor != 0 or batch.shape[3] % self.vae_scale_factor != 0:
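
Note: the floor-divide-then-multiply pattern snaps the randomly scaled size down to the nearest multiple of the VAE scale factor, so the encoder never sees a resolution it cannot divide evenly. A standalone check with assumed numbers:

    vae_scale_factor = 8                      # assumed VAE downscale factor
    h, w = 600, 600                           # assumed batch spatial size
    scale_factor = 0.25
    new_size = (int(h * scale_factor), int(w * scale_factor))          # (150, 150)
    new_size = (new_size[0] // vae_scale_factor * vae_scale_factor,
                new_size[1] // vae_scale_factor * vae_scale_factor)    # (144, 144)
    print(new_size)
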
@@ -615,6 +635,11 @@ class TrainVAEProcess(BaseTrainProcess):
                         if do_flip_y > 0:
                             latent_chunks[i] = torch.flip(latent_chunks[i], [3])
                             batch_chunks[i] = torch.flip(batch_chunks[i], [3])
+
+                        # resize latent to fit
+                        if latent_chunks[i].shape[2] != latent_size or latent_chunks[i].shape[3] != latent_size:
+                            latent_chunks[i] = torch.nn.functional.interpolate(latent_chunks[i], size=(latent_size, latent_size), mode='bilinear', align_corners=False)
+
                         # if do_scale > 0:
                         # scale = 2
                         # start_latent_h = latent_chunks[i].shape[2]
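
Note: because randomly scaled batches produce latents of varying spatial size, the added interpolate call brings each latent chunk back to the fixed latent_size before the rest of the step runs. A tiny standalone illustration; the tensor shapes are assumed:

    import torch
    import torch.nn.functional as F

    latent_size = 64                    # assumed target latent size
    latent = torch.randn(1, 4, 32, 32)  # assumed: latent from a half-scaled batch
    if latent.shape[2] != latent_size or latent.shape[3] != latent_size:
        latent = F.interpolate(latent, size=(latent_size, latent_size),
                               mode='bilinear', align_corners=False)
    print(latent.shape)  # torch.Size([1, 4, 64, 64])
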
@@ -643,6 +668,10 @@ class TrainVAEProcess(BaseTrainProcess):
                         forward_latents = channel_dropout(latents, self.dropout)
                     else:
                         forward_latents = latents

+                    # resize batch to resolution if needed
+                    if batch_chunks[0].shape[2] != self.resolution or batch_chunks[0].shape[3] != self.resolution:
+                        batch_chunks = [torch.nn.functional.interpolate(b, size=(self.resolution, self.resolution), mode='bilinear', align_corners=False) for b in batch_chunks]
+                        batch = torch.cat(batch_chunks, dim=0)

                 else:
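
Note: channel_dropout() comes from elsewhere in the toolkit and its implementation is not part of this diff. The sketch below is only a guess at the general idea (zeroing whole latent channels with probability p before they are passed forward), not the project's actual helper.

    import torch

    def channel_dropout_sketch(latents: torch.Tensor, p: float) -> torch.Tensor:
        # Hypothetical stand-in for the toolkit's channel_dropout():
        # drop each latent channel independently with probability p.
        if p <= 0.0:
            return latents
        keep = (torch.rand(latents.shape[0], latents.shape[1], 1, 1,
                           device=latents.device) > p).to(latents.dtype)
        return latents * keep
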
@@ -220,10 +220,15 @@ class Critic:

         return float(np.mean(critic_losses))

-    def get_lr(self):
-        if self.optimizer_type.startswith('dadaptation'):
-            return (
-                self.optimizer.param_groups[0]["d"]
-                * self.optimizer.param_groups[0]["lr"]
-            )
-        return self.optimizer.param_groups[0]["lr"]
+    def get_lr(self):
+        if hasattr(self.optimizer, 'get_avg_learning_rate'):
+            learning_rate = self.optimizer.get_avg_learning_rate()
+        elif self.optimizer_type.startswith('dadaptation') or \
+                self.optimizer_type.lower().startswith('prodigy'):
+            learning_rate = (
+                self.optimizer.param_groups[0]["d"] *
+                self.optimizer.param_groups[0]["lr"]
+            )
+        else:
+            learning_rate = self.optimizer.param_groups[0]['lr']
+        return learning_rate
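
Note: the reworked get_lr() prefers an optimizer-reported average learning rate when the optimizer exposes get_avg_learning_rate(), falls back to d * lr for dadaptation/prodigy-style adaptive optimizers, and otherwise returns the plain param-group lr. A rough usage sketch with a dummy optimizer; the dummy class and its values are purely illustrative:

    class DummyOptimizer:
        # Illustrative stand-in, not a real torch optimizer.
        param_groups = [{'lr': 1e-4, 'd': 2.0}]

        def get_avg_learning_rate(self):
            return 5e-5

    opt = DummyOptimizer()
    # Mirrors the priority order in the new Critic.get_lr():
    if hasattr(opt, 'get_avg_learning_rate'):
        lr = opt.get_avg_learning_rate()
    else:
        lr = opt.param_groups[0]['lr']
    print(lr)  # 5e-05
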