Fixed issue with bucket dataloader corpping in too much. Added normalization capabilities to LoRA modules. Testing effects, but should prevent them from burning and also make them more compatable with stacking many LoRAs

This commit is contained in:
Jaret Burkett
2023-08-27 09:40:01 -06:00
parent 6bd3851058
commit 9b164a8688
5 changed files with 190 additions and 103 deletions

View File

@@ -208,7 +208,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
if self.network is not None:
prev_multiplier = self.network.multiplier
self.network.multiplier = 1.0
# TODO handle dreambooth, fine tuning, etc
if self.network_config.normalize:
# apply the normalization
self.network.apply_stored_normalizer()
self.network.save_weights(
file_path,
dtype=get_torch_dtype(self.save_config.dtype),
@@ -323,7 +325,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
imgs = imgs.to(self.device_torch, dtype=dtype)
latents = self.sd.encode_images(imgs)
self.sd.noise_scheduler.set_timesteps(
self.train_config.max_denoising_steps, device=self.device_torch
)
@@ -429,6 +430,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
if self.train_config.gradient_checkpointing:
self.network.enable_gradient_checkpointing()
# set the network to normalize if we are
self.network.is_normalizing = self.network_config.normalize
latest_save_path = self.get_latest_save_path()
if latest_save_path is not None:
self.print(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####")
@@ -522,71 +526,84 @@ class BaseSDTrainProcess(BaseTrainProcess):
dataloader_reg = None
dataloader_iterator_reg = None
# zero any gradients
optimizer.zero_grad()
# self.step_num = 0
for step in range(self.step_num, self.train_config.steps):
# if is even step and we have a reg dataset, use that
# todo improve this logic to send one of each through if we can buckets and batch size might be an issue
if step % 2 == 0 and dataloader_reg is not None:
try:
batch = next(dataloader_iterator_reg)
except StopIteration:
# hit the end of an epoch, reset
dataloader_iterator_reg = iter(dataloader_reg)
batch = next(dataloader_iterator_reg)
elif dataloader is not None:
try:
batch = next(dataloader_iterator)
except StopIteration:
# hit the end of an epoch, reset
dataloader_iterator = iter(dataloader)
batch = next(dataloader_iterator)
else:
batch = None
with torch.no_grad():
# if is even step and we have a reg dataset, use that
# todo improve this logic to send one of each through if we can buckets and batch size might be an issue
if step % 2 == 0 and dataloader_reg is not None:
try:
batch = next(dataloader_iterator_reg)
except StopIteration:
# hit the end of an epoch, reset
dataloader_iterator_reg = iter(dataloader_reg)
batch = next(dataloader_iterator_reg)
elif dataloader is not None:
try:
batch = next(dataloader_iterator)
except StopIteration:
# hit the end of an epoch, reset
dataloader_iterator = iter(dataloader)
batch = next(dataloader_iterator)
else:
batch = None
# turn on normalization if we are using it and it is not on
if self.network is not None and self.network_config.normalize and not self.network.is_normalizing:
self.network.is_normalizing = True
### HOOK ###
loss_dict = self.hook_train_loop(batch)
flush()
if self.train_config.optimizer.lower().startswith('dadaptation') or \
self.train_config.optimizer.lower().startswith('prodigy'):
learning_rate = (
optimizer.param_groups[0]["d"] *
optimizer.param_groups[0]["lr"]
)
else:
learning_rate = optimizer.param_groups[0]['lr']
with torch.no_grad():
if self.train_config.optimizer.lower().startswith('dadaptation') or \
self.train_config.optimizer.lower().startswith('prodigy'):
learning_rate = (
optimizer.param_groups[0]["d"] *
optimizer.param_groups[0]["lr"]
)
else:
learning_rate = optimizer.param_groups[0]['lr']
prog_bar_string = f"lr: {learning_rate:.1e}"
for key, value in loss_dict.items():
prog_bar_string += f" {key}: {value:.3e}"
prog_bar_string = f"lr: {learning_rate:.1e}"
for key, value in loss_dict.items():
prog_bar_string += f" {key}: {value:.3e}"
self.progress_bar.set_postfix_str(prog_bar_string)
self.progress_bar.set_postfix_str(prog_bar_string)
# don't do on first step
if self.step_num != self.start_step:
# pause progress bar
self.progress_bar.unpause() # makes it so doesn't track time
if self.sample_config.sample_every and self.step_num % self.sample_config.sample_every == 0:
# print above the progress bar
self.sample(self.step_num)
# don't do on first step
if self.step_num != self.start_step:
# pause progress bar
self.progress_bar.unpause() # makes it so doesn't track time
if self.sample_config.sample_every and self.step_num % self.sample_config.sample_every == 0:
# print above the progress bar
self.sample(self.step_num)
if self.save_config.save_every and self.step_num % self.save_config.save_every == 0:
# print above the progress bar
self.print(f"Saving at step {self.step_num}")
self.save(self.step_num)
if self.save_config.save_every and self.step_num % self.save_config.save_every == 0:
# print above the progress bar
self.print(f"Saving at step {self.step_num}")
self.save(self.step_num)
if self.logging_config.log_every and self.step_num % self.logging_config.log_every == 0:
# log to tensorboard
if self.writer is not None:
for key, value in loss_dict.items():
self.writer.add_scalar(f"{key}", value, self.step_num)
self.writer.add_scalar(f"lr", learning_rate, self.step_num)
self.progress_bar.refresh()
if self.logging_config.log_every and self.step_num % self.logging_config.log_every == 0:
# log to tensorboard
if self.writer is not None:
for key, value in loss_dict.items():
self.writer.add_scalar(f"{key}", value, self.step_num)
self.writer.add_scalar(f"lr", learning_rate, self.step_num)
self.progress_bar.refresh()
# sets progress bar to match out step
self.progress_bar.update(step - self.progress_bar.n)
# end of step
self.step_num = step
# sets progress bar to match out step
self.progress_bar.update(step - self.progress_bar.n)
# end of step
self.step_num = step
# apply network normalizer if we are using it
if self.network is not None and self.network.is_normalizing:
self.network.apply_stored_normalizer()
self.sample(self.step_num + 1)
print("")