ostris/ai-toolkit (mirror of https://github.com/ostris/ai-toolkit.git)
Commit: Various experiments and minor bug fixes for edge cases
@@ -69,6 +69,7 @@ import transformers
 import diffusers
 import hashlib
 
+from toolkit.util.blended_blur_noise import get_blended_blur_noise
 from toolkit.util.get_model import get_model_class
 
 def flush():
@@ -903,7 +904,14 @@ class BaseSDTrainProcess(BaseTrainProcess):
         return noise
 
 
-    def get_noise(self, latents, batch_size, dtype=torch.float32, batch: 'DataLoaderBatchDTO' = None):
+    def get_noise(
+        self,
+        latents,
+        batch_size,
+        dtype=torch.float32,
+        batch: 'DataLoaderBatchDTO' = None,
+        timestep=None,
+    ):
         if self.train_config.optimal_noise_pairing_samples > 1:
             noise = self.get_optimal_noise(latents, dtype=dtype)
         elif self.train_config.force_consistent_noise:
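The signature change only threads a timestep argument through get_noise(). For context, the optimal_noise_pairing_samples branch above draws several candidate noises and keeps a best match per latent. A minimal sketch of that selection idea, assuming a mean-squared-distance criterion over 4D (B, C, H, W) latents; the helper name and metric below are illustrative, not the repo's actual get_optimal_noise:

import torch

def pick_paired_noise(latents: torch.Tensor, num_samples: int) -> torch.Tensor:
    # draw several candidate noises, keep the closest one per batch item
    candidates = torch.randn(
        num_samples, *latents.shape, device=latents.device, dtype=latents.dtype
    )
    # mean squared distance of each candidate to the latents, per item
    dists = ((candidates - latents.unsqueeze(0)) ** 2).mean(dim=(2, 3, 4))
    best = dists.argmin(dim=0)                       # (batch,)
    batch_idx = torch.arange(latents.shape[0], device=latents.device)
    return candidates[best, batch_idx]               # (batch, C, H, W)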
@@ -933,12 +941,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
 
             # add to noise
             noise += noise_shift
 
-        # standardize the noise
-        # shouldnt be needed?
-        # std = noise.std(dim=(2, 3), keepdim=True)
-        # normalizer = 1 / (std + 1e-6)
-        # noise = noise * normalizer
+        if self.train_config.blended_blur_noise:
+            noise = get_blended_blur_noise(
+                latents, noise, timestep
+            )
 
         return noise
 
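get_blended_blur_noise() now receives the timestep, so the blending can depend on where the sample sits in the schedule. A minimal sketch of one plausible reading, assuming the helper mixes Gaussian noise with a blurred, re-standardized copy of the latents, weighted by normalized timestep; the blur parameters and weighting are illustrative guesses, not the repo's actual formula:

import torch
from torchvision.transforms.functional import gaussian_blur

def blended_blur_noise_sketch(latents, noise, timestep, max_timestep=1000):
    # blur the latents so low-frequency image structure leaks into the noise
    blurred = gaussian_blur(latents, kernel_size=[9, 9], sigma=3.0)
    # re-standardize so the blend still behaves like unit-variance noise
    blurred = (blurred - blurred.mean()) / (blurred.std() + 1e-6)
    # blend more heavily at high timesteps (illustrative weighting)
    t = torch.as_tensor(timestep, dtype=latents.dtype, device=latents.device)
    t = (t / max_timestep).clamp(0, 1).view(-1, 1, 1, 1)
    return (1.0 - t) * noise + t * blurred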
@@ -1193,7 +1200,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
         timesteps = torch.stack(timesteps, dim=0)
 
         # get noise
-        noise = self.get_noise(latents, batch_size, dtype=dtype, batch=batch)
+        noise = self.get_noise(latents, batch_size, dtype=dtype, batch=batch, timestep=timesteps)
 
         # add dynamic noise offset. Dynamic noise is offsetting the noise to the same channelwise mean as the latents
         # this will negate any noise offsets
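The call site now forwards the stacked per-sample timesteps as timestep=timesteps. The dynamic-noise-offset comment above describes a channelwise mean match; a minimal sketch of that operation as described (the function and variable names are mine):

import torch

def apply_dynamic_noise_offset(noise: torch.Tensor, latents: torch.Tensor) -> torch.Tensor:
    # shift each noise channel so its spatial mean matches the same
    # channel's mean in the latents, as the comment above describes
    latent_mean = latents.mean(dim=(2, 3), keepdim=True)  # (B, C, 1, 1)
    noise_mean = noise.mean(dim=(2, 3), keepdim=True)     # (B, C, 1, 1)
    return noise - noise_mean + latent_mean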
@@ -1924,10 +1931,14 @@ class BaseSDTrainProcess(BaseTrainProcess):
 
         start_step_num = self.step_num
         did_first_flush = False
+        flush_next = False
         for step in range(start_step_num, self.train_config.steps):
             if self.train_config.do_paramiter_swapping:
                 self.optimizer.optimizer.swap_paramiters()
             self.timer.start('train_loop')
+            if flush_next:
+                flush()
+                flush_next = False
             if self.train_config.do_random_cfg:
                 self.train_config.do_cfg = True
                 self.train_config.cfg_scale = value_map(random.random(), 0, 1, 1.0, self.train_config.max_cfg_scale)
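In the random-CFG branch, value_map() is used as a linear remap from [0, 1] to [1.0, max_cfg_scale]. A drop-in sketch consistent with the call site; the signature is inferred from the arguments, not copied from the repo:

import random

def value_map(value, min_in, max_in, min_out, max_out):
    # linearly rescale value from [min_in, max_in] to [min_out, max_out]
    return min_out + (value - min_in) * (max_out - min_out) / (max_in - min_in)

# e.g. draw a CFG scale uniformly from [1.0, max_cfg_scale]
max_cfg_scale = 7.0  # placeholder value
cfg_scale = value_map(random.random(), 0, 1, 1.0, max_cfg_scale)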
@@ -2089,6 +2100,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
                     print_acc(f"\nSaving at step {self.step_num}")
                     self.save(self.step_num)
+                    self.ensure_params_requires_grad()
+                    # clear any grads
+                    optimizer.zero_grad()
                     flush()
+                    flush_next = True
                     if self.progress_bar is not None:
                         self.progress_bar.unpause()
 
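The save path now also re-enables requires_grad on trained params, zeros any leftover grads, and schedules a second flush for the top of the next step via flush_next. A minimal sketch of the deferred-flush pattern, assuming flush() is the usual gc pass plus CUDA cache release; the loop below is a stand-in, not the trainer's actual control flow:

import gc
import torch

def flush():
    # assumed body: the usual gc pass plus CUDA cache release
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

def training_loop(num_steps, save_every):
    flush_next = False
    for step in range(num_steps):
        if flush_next:
            # deferred from the previous step's save: this second flush
            # catches allocations made while saving, after the progress
            # bar has already resumed
            flush()
            flush_next = False
        # ... one training step would run here ...
        if step > 0 and step % save_every == 0:
            print(f"Saving at step {step}")   # stand-in for self.save(...)
            flush()                           # immediate flush after saving
            flush_next = True                 # schedule one more next step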