Fixed an issue with the bucket dataloader cropping in too much. Added normalization capabilities to LoRA modules. Still testing the effects, but it should prevent them from burning and also make them more compatible with stacking many LoRAs.
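For context on what "normalization" means here, below is a minimal sketch of a self-normalizing LoRA module. Everything in it is an illustrative assumption: the class, the norm heuristic, and the rescaling rule are hypothetical, not ai-toolkit's actual implementation; only the names is_normalizing and apply_stored_normalizer come from the diff below.

import torch
import torch.nn as nn


class NormalizedLoRAModule(nn.Module):
    # Hypothetical sketch, not the ai-toolkit implementation: track the
    # magnitude of the low-rank update during training, then fold a
    # compensating scale into the weights so no single module's update
    # grows unbounded (the "burning" the commit message mentions).
    def __init__(self, in_dim: int, out_dim: int, rank: int = 4):
        super().__init__()
        self.down = nn.Linear(in_dim, rank, bias=False)
        self.up = nn.Linear(rank, out_dim, bias=False)
        nn.init.kaiming_uniform_(self.down.weight)
        nn.init.zeros_(self.up.weight)
        self.is_normalizing = False
        self._stored_scale = 1.0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        delta = self.up(self.down(x))
        if self.is_normalizing:
            # assumed heuristic: keep the effective update norm near 1.0
            weight_norm = (self.up.weight @ self.down.weight).norm().item()
            if weight_norm > 1.0:
                self._stored_scale = 1.0 / weight_norm
        return delta

    @torch.no_grad()
    def apply_stored_normalizer(self):
        # fold the stored scale into the up projection, then reset it
        self.up.weight.mul_(self._stored_scale)
        self._stored_scale = 1.0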
@@ -208,7 +208,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
         if self.network is not None:
             prev_multiplier = self.network.multiplier
             self.network.multiplier = 1.0
             # TODO handle dreambooth, fine tuning, etc
+            if self.network_config.normalize:
+                # apply the normalization
+                self.network.apply_stored_normalizer()
             self.network.save_weights(
                 file_path,
                 dtype=get_torch_dtype(self.save_config.dtype),
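The hunk above saves with the multiplier forced to 1.0 and, when normalization is enabled, folds the stored normalizer into the weights first, so the serialized file reflects the network at full strength with the normalization baked in. A runnable sketch of that pattern, with a hypothetical stand-in for the network object (the restore step is assumed; it falls outside the lines shown):

import torch


class TinyNetwork:
    # hypothetical stand-in for self.network, for illustration only
    def __init__(self):
        self.multiplier = 0.8
        self._scale = 0.5
        self.weight = torch.ones(4, 4)

    def apply_stored_normalizer(self):
        self.weight.mul_(self._scale)
        self._scale = 1.0

    def save_weights(self, path, dtype=torch.float16):
        torch.save({"weight": self.weight.to(dtype)}, path)


network = TinyNetwork()
prev_multiplier = network.multiplier
network.multiplier = 1.0              # serialize at full strength
network.apply_stored_normalizer()     # bake the normalizer into the weights
network.save_weights("lora_checkpoint.pt")
network.multiplier = prev_multiplier  # assumed restore; not in the hunk shown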
@@ -323,7 +325,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 imgs = imgs.to(self.device_torch, dtype=dtype)
                 latents = self.sd.encode_images(imgs)
-

                 self.sd.noise_scheduler.set_timesteps(
                     self.train_config.max_denoising_steps, device=self.device_torch
                 )
@@ -429,6 +430,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
             if self.train_config.gradient_checkpointing:
                 self.network.enable_gradient_checkpointing()

+            # set the network to normalize if we are
+            self.network.is_normalizing = self.network_config.normalize
+
             latest_save_path = self.get_latest_save_path()
             if latest_save_path is not None:
                 self.print(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####")
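On enable_gradient_checkpointing just above: checkpointing trades compute for memory by recomputing activations during the backward pass instead of caching them, which matters at SD training resolutions. The network method's internals aren't shown in this diff; the following is only a generic torch sketch of the idea:

import torch
from torch.utils.checkpoint import checkpoint

weight = torch.randn(64, 64, requires_grad=True)


def block(x):
    # stand-in for an expensive submodule's forward
    return torch.relu(x @ weight)


x = torch.randn(8, 64, requires_grad=True)
# activations inside `block` are recomputed during backward rather than
# kept in memory for the whole forward pass
y = checkpoint(block, x, use_reentrant=False)
y.sum().backward()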
@@ -522,71 +526,84 @@ class BaseSDTrainProcess(BaseTrainProcess):
             dataloader_reg = None
             dataloader_iterator_reg = None

         # zero any gradients
         optimizer.zero_grad()

         # self.step_num = 0
         for step in range(self.step_num, self.train_config.steps):
-            # if is even step and we have a reg dataset, use that
-            # todo improve this logic to send one of each through if we can buckets and batch size might be an issue
-            if step % 2 == 0 and dataloader_reg is not None:
-                try:
-                    batch = next(dataloader_iterator_reg)
-                except StopIteration:
-                    # hit the end of an epoch, reset
-                    dataloader_iterator_reg = iter(dataloader_reg)
-                    batch = next(dataloader_iterator_reg)
-            elif dataloader is not None:
-                try:
-                    batch = next(dataloader_iterator)
-                except StopIteration:
-                    # hit the end of an epoch, reset
-                    dataloader_iterator = iter(dataloader)
-                    batch = next(dataloader_iterator)
-            else:
-                batch = None
+            with torch.no_grad():
+                # if is even step and we have a reg dataset, use that
+                # todo improve this logic to send one of each through if we can buckets and batch size might be an issue
+                if step % 2 == 0 and dataloader_reg is not None:
+                    try:
+                        batch = next(dataloader_iterator_reg)
+                    except StopIteration:
+                        # hit the end of an epoch, reset
+                        dataloader_iterator_reg = iter(dataloader_reg)
+                        batch = next(dataloader_iterator_reg)
+                elif dataloader is not None:
+                    try:
+                        batch = next(dataloader_iterator)
+                    except StopIteration:
+                        # hit the end of an epoch, reset
+                        dataloader_iterator = iter(dataloader)
+                        batch = next(dataloader_iterator)
+                else:
+                    batch = None
+
+                # turn on normalization if we are using it and it is not on
+                if self.network is not None and self.network_config.normalize and not self.network.is_normalizing:
+                    self.network.is_normalizing = True

             ### HOOK ###
             loss_dict = self.hook_train_loop(batch)
             flush()

-            if self.train_config.optimizer.lower().startswith('dadaptation') or \
-                    self.train_config.optimizer.lower().startswith('prodigy'):
-                learning_rate = (
-                        optimizer.param_groups[0]["d"] *
-                        optimizer.param_groups[0]["lr"]
-                )
-            else:
-                learning_rate = optimizer.param_groups[0]['lr']
+            with torch.no_grad():
+                if self.train_config.optimizer.lower().startswith('dadaptation') or \
+                        self.train_config.optimizer.lower().startswith('prodigy'):
+                    learning_rate = (
+                            optimizer.param_groups[0]["d"] *
+                            optimizer.param_groups[0]["lr"]
+                    )
+                else:
+                    learning_rate = optimizer.param_groups[0]['lr']

-            prog_bar_string = f"lr: {learning_rate:.1e}"
-            for key, value in loss_dict.items():
-                prog_bar_string += f" {key}: {value:.3e}"
+                prog_bar_string = f"lr: {learning_rate:.1e}"
+                for key, value in loss_dict.items():
+                    prog_bar_string += f" {key}: {value:.3e}"

-            self.progress_bar.set_postfix_str(prog_bar_string)
+                self.progress_bar.set_postfix_str(prog_bar_string)

-            # don't do on first step
-            if self.step_num != self.start_step:
-                # pause progress bar
-                self.progress_bar.unpause()  # makes it so doesn't track time
-                if self.sample_config.sample_every and self.step_num % self.sample_config.sample_every == 0:
-                    # print above the progress bar
-                    self.sample(self.step_num)
+                # don't do on first step
+                if self.step_num != self.start_step:
+                    # pause progress bar
+                    self.progress_bar.unpause()  # makes it so doesn't track time
+                    if self.sample_config.sample_every and self.step_num % self.sample_config.sample_every == 0:
+                        # print above the progress bar
+                        self.sample(self.step_num)

-                if self.save_config.save_every and self.step_num % self.save_config.save_every == 0:
-                    # print above the progress bar
-                    self.print(f"Saving at step {self.step_num}")
-                    self.save(self.step_num)
+                    if self.save_config.save_every and self.step_num % self.save_config.save_every == 0:
+                        # print above the progress bar
+                        self.print(f"Saving at step {self.step_num}")
+                        self.save(self.step_num)

-                if self.logging_config.log_every and self.step_num % self.logging_config.log_every == 0:
-                    # log to tensorboard
-                    if self.writer is not None:
-                        for key, value in loss_dict.items():
-                            self.writer.add_scalar(f"{key}", value, self.step_num)
-                        self.writer.add_scalar(f"lr", learning_rate, self.step_num)
-                self.progress_bar.refresh()
+                    if self.logging_config.log_every and self.step_num % self.logging_config.log_every == 0:
+                        # log to tensorboard
+                        if self.writer is not None:
+                            for key, value in loss_dict.items():
+                                self.writer.add_scalar(f"{key}", value, self.step_num)
+                            self.writer.add_scalar(f"lr", learning_rate, self.step_num)
+                    self.progress_bar.refresh()

-            # sets progress bar to match out step
-            self.progress_bar.update(step - self.progress_bar.n)
-            # end of step
-            self.step_num = step
+                # sets progress bar to match out step
+                self.progress_bar.update(step - self.progress_bar.n)
+                # end of step
+                self.step_num = step
+
+                # apply network normalizer if we are using it
+                if self.network is not None and self.network.is_normalizing:
+                    self.network.apply_stored_normalizer()

         self.sample(self.step_num + 1)
         print("")
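A standalone illustration of the batch-fetching pattern in the loop above: catching StopIteration and rebuilding the iterator turns an epoch-bound DataLoader into an endless stream, which is what lets training run for an arbitrary train_config.steps. The helper name cycle and the toy dataset are assumptions for the sketch.

import torch
from torch.utils.data import DataLoader, TensorDataset


def cycle(dataloader: DataLoader):
    # yield batches forever; on epoch end, reset the iterator
    iterator = iter(dataloader)
    while True:
        try:
            yield next(iterator)
        except StopIteration:
            # hit the end of an epoch, reset
            iterator = iter(dataloader)
            yield next(iterator)


# toy usage: 10 samples at batch size 4 means an "epoch" is 3 batches,
# but we can draw as many steps as we like
loader = DataLoader(TensorDataset(torch.randn(10, 3)), batch_size=4)
stream = cycle(loader)
for step in range(7):
    (batch,) = next(stream)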