Various experiments and minor bug fixes for edge cases

commit 88b3fbae37
parent 8ff85ba14f
Author: Jaret Burkett
Date: 2025-04-25 13:44:38 -06:00

8 changed files with 170 additions and 122 deletions

@@ -69,6 +69,7 @@ import transformers
 import diffusers
 import hashlib
 from toolkit.util.blended_blur_noise import get_blended_blur_noise
+from toolkit.util.get_model import get_model_class


 def flush():
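For reference, the flush() helper that the later hunks call is defined just below this import block. Its body is not part of this diff; a minimal sketch, assuming the usual PyTorch cleanup idiom:

    import gc
    import torch

    def flush():
        # drop Python-level references first, then release cached CUDA blocks
        gc.collect()
        torch.cuda.empty_cache()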
@@ -903,7 +904,14 @@ class BaseSDTrainProcess(BaseTrainProcess):
         return noise

-    def get_noise(self, latents, batch_size, dtype=torch.float32, batch: 'DataLoaderBatchDTO' = None):
+    def get_noise(
+        self,
+        latents,
+        batch_size,
+        dtype=torch.float32,
+        batch: 'DataLoaderBatchDTO' = None,
+        timestep=None,
+    ):
         if self.train_config.optimal_noise_pairing_samples > 1:
             noise = self.get_optimal_noise(latents, dtype=dtype)
         elif self.train_config.force_consistent_noise:
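Because the new timestep parameter defaults to None, call sites that predate this change keep working unchanged; only the training-loop call updated further down passes it:

    noise = self.get_noise(latents, batch_size, dtype=dtype, batch=batch)                      # older callers, timestep=None
    noise = self.get_noise(latents, batch_size, dtype=dtype, batch=batch, timestep=timesteps)  # new call site below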
@@ -933,12 +941,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
             # add to noise
             noise += noise_shift

         # standardize the noise
         # shouldnt be needed?
         # std = noise.std(dim=(2, 3), keepdim=True)
         # normalizer = 1 / (std + 1e-6)
         # noise = noise * normalizer

         if self.train_config.blended_blur_noise:
-            noise = get_blended_blur_noise(latents, noise)
+            noise = get_blended_blur_noise(
+                latents, noise, timestep
+            )
         return noise
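The real get_blended_blur_noise lives in toolkit/util/blended_blur_noise.py and is not part of this diff. A hypothetical sketch of the idea, mixing a blurred copy of the latents into the noise as a function of the timestep, could look like this (the function body, the blur choice, and the 1000-step scale are all assumptions):

    import torch
    import torch.nn.functional as F

    def get_blended_blur_noise(latents, noise, timestep, num_timesteps=1000):
        # cheap stand-in for a gaussian blur of the latents
        blurred = F.avg_pool2d(latents, kernel_size=3, stride=1, padding=1)
        # per-sample blend weight derived from the (B,) timestep tensor
        t = (timestep.float() / num_timesteps).view(-1, 1, 1, 1)
        # more low-frequency latent structure in the noise at high timesteps
        return noise * (1.0 - t) + blurred * t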
@@ -1193,7 +1200,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
             timesteps = torch.stack(timesteps, dim=0)

         # get noise
-        noise = self.get_noise(latents, batch_size, dtype=dtype, batch=batch)
+        noise = self.get_noise(latents, batch_size, dtype=dtype, batch=batch, timestep=timesteps)

         # add dynamic noise offset. Dynamic noise is offsetting the noise to the same channelwise mean as the latents
         # this will negate any noise offsets
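The dynamic noise offset the comment describes amounts to shifting the noise so its per-channel mean matches the latents'. A minimal sketch of that operation, assuming (B, C, H, W) tensors (an illustration of the comment, not the code from the file):

    latent_mean = latents.mean(dim=(2, 3), keepdim=True)  # per-sample, per-channel mean
    noise_mean = noise.mean(dim=(2, 3), keepdim=True)
    # match the noise's channelwise mean to the latents', which cancels
    # any fixed noise offset applied earlier
    noise = noise - noise_mean + latent_mean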
@@ -1924,10 +1931,14 @@ class BaseSDTrainProcess(BaseTrainProcess):
         start_step_num = self.step_num
         did_first_flush = False
+        flush_next = False
         for step in range(start_step_num, self.train_config.steps):
             if self.train_config.do_paramiter_swapping:
                 self.optimizer.optimizer.swap_paramiters()
             self.timer.start('train_loop')
+            if flush_next:
+                flush()
+                flush_next = False
             if self.train_config.do_random_cfg:
                 self.train_config.do_cfg = True
                 self.train_config.cfg_scale = value_map(random.random(), 0, 1, 1.0, self.train_config.max_cfg_scale)
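value_map here draws a CFG scale uniformly between 1.0 and max_cfg_scale. Its definition is not in this diff; inferring from the call site, it is presumably a linear remap along these lines:

    def value_map(x, min_in, max_in, min_out, max_out):
        # linearly remap x from [min_in, max_in] to [min_out, max_out]
        return min_out + (x - min_in) * (max_out - min_out) / (max_in - min_in)

    # value_map(random.random(), 0, 1, 1.0, max_cfg_scale) -> uniform in [1.0, max_cfg_scale]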
@@ -2089,6 +2100,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 print_acc(f"\nSaving at step {self.step_num}")
                 self.save(self.step_num)
                 self.ensure_params_requires_grad()
+                # clear any grads
+                optimizer.zero_grad()
+                flush()
+                flush_next = True
             if self.progress_bar is not None:
                 self.progress_bar.unpause()
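Together with the check added at the top of the loop, this gives a two-phase cleanup after each save: zero_grad() drops the gradient buffers, the immediate flush() reclaims what the save released, and flush_next schedules one more flush() on the next iteration, once the current step's remaining tensors have gone out of scope. A control-flow sketch assembled from the hunks above (not a literal excerpt; is_save_step is a hypothetical condition name):

    flush_next = False
    for step in range(start_step_num, total_steps):
        if flush_next:
            # deferred flush: runs one iteration after a save, when the
            # previous step's activations and loss graph are already freed
            flush()
            flush_next = False
        ...  # forward, backward, optimizer step
        if is_save_step:
            save(step)
            optimizer.zero_grad()  # release gradient buffers before flushing
            flush()                # immediate flush
            flush_next = True      # schedule the deferred flush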