Fixed issue with the bucket dataloader cropping in too much. Added normalization capabilities to LoRA modules. Still testing the effects, but this should prevent them from burning and also make them more compatible with stacking many LoRAs.

This commit is contained in:
Jaret Burkett
2023-08-27 09:40:01 -06:00
parent 6bd3851058
commit 9b164a8688
5 changed files with 190 additions and 103 deletions

View File

@@ -208,7 +208,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
if self.network is not None:
prev_multiplier = self.network.multiplier
self.network.multiplier = 1.0
# TODO handle dreambooth, fine tuning, etc
if self.network_config.normalize:
# apply the normalization
self.network.apply_stored_normalizer()
self.network.save_weights(
file_path,
dtype=get_torch_dtype(self.save_config.dtype),
@@ -323,7 +325,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
imgs = imgs.to(self.device_torch, dtype=dtype)
latents = self.sd.encode_images(imgs)
self.sd.noise_scheduler.set_timesteps(
self.train_config.max_denoising_steps, device=self.device_torch
)
@@ -429,6 +430,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
if self.train_config.gradient_checkpointing:
self.network.enable_gradient_checkpointing()
# set the network to normalize if we are
self.network.is_normalizing = self.network_config.normalize
latest_save_path = self.get_latest_save_path()
if latest_save_path is not None:
self.print(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####")
@@ -522,71 +526,84 @@ class BaseSDTrainProcess(BaseTrainProcess):
dataloader_reg = None
dataloader_iterator_reg = None
# zero any gradients
optimizer.zero_grad()
# self.step_num = 0
for step in range(self.step_num, self.train_config.steps):
# if is even step and we have a reg dataset, use that
# todo improve this logic to send one of each through if we can buckets and batch size might be an issue
if step % 2 == 0 and dataloader_reg is not None:
try:
batch = next(dataloader_iterator_reg)
except StopIteration:
# hit the end of an epoch, reset
dataloader_iterator_reg = iter(dataloader_reg)
batch = next(dataloader_iterator_reg)
elif dataloader is not None:
try:
batch = next(dataloader_iterator)
except StopIteration:
# hit the end of an epoch, reset
dataloader_iterator = iter(dataloader)
batch = next(dataloader_iterator)
else:
batch = None
with torch.no_grad():
# if is even step and we have a reg dataset, use that
# todo improve this logic to send one of each through if we can buckets and batch size might be an issue
if step % 2 == 0 and dataloader_reg is not None:
try:
batch = next(dataloader_iterator_reg)
except StopIteration:
# hit the end of an epoch, reset
dataloader_iterator_reg = iter(dataloader_reg)
batch = next(dataloader_iterator_reg)
elif dataloader is not None:
try:
batch = next(dataloader_iterator)
except StopIteration:
# hit the end of an epoch, reset
dataloader_iterator = iter(dataloader)
batch = next(dataloader_iterator)
else:
batch = None
# turn on normalization if we are using it and it is not on
if self.network is not None and self.network_config.normalize and not self.network.is_normalizing:
self.network.is_normalizing = True
### HOOK ###
loss_dict = self.hook_train_loop(batch)
flush()
if self.train_config.optimizer.lower().startswith('dadaptation') or \
self.train_config.optimizer.lower().startswith('prodigy'):
learning_rate = (
optimizer.param_groups[0]["d"] *
optimizer.param_groups[0]["lr"]
)
else:
learning_rate = optimizer.param_groups[0]['lr']
with torch.no_grad():
if self.train_config.optimizer.lower().startswith('dadaptation') or \
self.train_config.optimizer.lower().startswith('prodigy'):
learning_rate = (
optimizer.param_groups[0]["d"] *
optimizer.param_groups[0]["lr"]
)
else:
learning_rate = optimizer.param_groups[0]['lr']
prog_bar_string = f"lr: {learning_rate:.1e}"
for key, value in loss_dict.items():
prog_bar_string += f" {key}: {value:.3e}"
prog_bar_string = f"lr: {learning_rate:.1e}"
for key, value in loss_dict.items():
prog_bar_string += f" {key}: {value:.3e}"
self.progress_bar.set_postfix_str(prog_bar_string)
self.progress_bar.set_postfix_str(prog_bar_string)
# don't do on first step
if self.step_num != self.start_step:
# pause progress bar
self.progress_bar.unpause() # makes it so doesn't track time
if self.sample_config.sample_every and self.step_num % self.sample_config.sample_every == 0:
# print above the progress bar
self.sample(self.step_num)
# don't do on first step
if self.step_num != self.start_step:
# pause progress bar
self.progress_bar.unpause() # makes it so doesn't track time
if self.sample_config.sample_every and self.step_num % self.sample_config.sample_every == 0:
# print above the progress bar
self.sample(self.step_num)
if self.save_config.save_every and self.step_num % self.save_config.save_every == 0:
# print above the progress bar
self.print(f"Saving at step {self.step_num}")
self.save(self.step_num)
if self.save_config.save_every and self.step_num % self.save_config.save_every == 0:
# print above the progress bar
self.print(f"Saving at step {self.step_num}")
self.save(self.step_num)
if self.logging_config.log_every and self.step_num % self.logging_config.log_every == 0:
# log to tensorboard
if self.writer is not None:
for key, value in loss_dict.items():
self.writer.add_scalar(f"{key}", value, self.step_num)
self.writer.add_scalar(f"lr", learning_rate, self.step_num)
self.progress_bar.refresh()
if self.logging_config.log_every and self.step_num % self.logging_config.log_every == 0:
# log to tensorboard
if self.writer is not None:
for key, value in loss_dict.items():
self.writer.add_scalar(f"{key}", value, self.step_num)
self.writer.add_scalar(f"lr", learning_rate, self.step_num)
self.progress_bar.refresh()
# sets progress bar to match out step
self.progress_bar.update(step - self.progress_bar.n)
# end of step
self.step_num = step
# sets progress bar to match out step
self.progress_bar.update(step - self.progress_bar.n)
# end of step
self.step_num = step
# apply network normalizer if we are using it
if self.network is not None and self.network.is_normalizing:
self.network.apply_stored_normalizer()
self.sample(self.step_num + 1)
print("")

View File

@@ -48,6 +48,7 @@ class NetworkConfig:
self.alpha: float = kwargs.get('alpha', 1.0)
self.linear_alpha: float = kwargs.get('linear_alpha', self.alpha)
self.conv_alpha: float = kwargs.get('conv_alpha', self.conv)
self.normalize = kwargs.get('normalize', False)
class EmbeddingConfig:

View File

@@ -95,34 +95,40 @@ class BucketsMixin:
# the other dimension should be the same ratio it is now (bigger)
new_width = resolution
new_height = resolution
new_x = file_item.crop_x
new_y = file_item.crop_y
if width > height:
# scale width to match new resolution,
new_width = int(width * (resolution / height))
file_item.crop_width = new_width
file_item.scale_to_width = new_width
file_item.crop_height = resolution
file_item.scale_to_height = resolution
# make sure new_width is divisible by bucket_tolerance
if new_width % bucket_tolerance != 0:
# reduce it to the nearest divisible number
reduction = new_width % bucket_tolerance
new_width = new_width - reduction
file_item.crop_width = new_width - reduction
# adjust the new x position so we evenly crop
new_x = int(new_x + (reduction / 2))
file_item.crop_x = int(file_item.crop_x + (reduction / 2))
elif height > width:
# scale height to match new resolution
new_height = int(height * (resolution / width))
file_item.crop_height = new_height
file_item.scale_to_height = new_height
file_item.scale_to_width = resolution
file_item.crop_width = resolution
# make sure new_height is divisible by bucket_tolerance
if new_height % bucket_tolerance != 0:
# reduce it to the nearest divisible number
reduction = new_height % bucket_tolerance
new_height = new_height - reduction
file_item.crop_height = new_height - reduction
# adjust the new x position so we evenly crop
new_y = int(new_y + (reduction / 2))
# add info to file
file_item.crop_x = new_x
file_item.crop_y = new_y
file_item.crop_width = new_width
file_item.crop_height = new_height
file_item.crop_y = int(file_item.crop_y + (reduction / 2))
else:
# square image
file_item.crop_height = resolution
file_item.scale_to_height = resolution
file_item.scale_to_width = resolution
file_item.crop_width = resolution
# check if bucket exists, if not, create it
bucket_key = f'{new_width}x{new_height}'

View File

@@ -8,6 +8,7 @@ import torch
from transformers import CLIPTextModel
from .paths import SD_SCRIPTS_ROOT
from .train_tools import get_torch_dtype
sys.path.append(SD_SCRIPTS_ROOT)
@@ -78,6 +79,8 @@ class LoRAModule(torch.nn.Module):
self.rank_dropout = rank_dropout
self.module_dropout = module_dropout
self.is_checkpointing = False
self.is_normalizing = False
self.normalize_scaler = 1.0
def apply_to(self):
self.org_forward = self.org_module.forward
@@ -91,8 +94,8 @@ class LoRAModule(torch.nn.Module):
batch_size = lora_up.size(0)
# batch will have all negative prompts first and positive prompts second
# our multiplier list is for a prompt pair. So we need to repeat it for positive and negative prompts
# if there is more than our multiplier, it is liekly a batch size increase, so we need to
# interleve the multipliers
# if there is more than our multiplier, it is likely a batch size increase, so we need to
# interleave the multipliers
if isinstance(self.multiplier, list):
if len(self.multiplier) == 0:
# single item, just return it
@@ -153,25 +156,30 @@ class LoRAModule(torch.nn.Module):
return lx * multiplier * scale
def create_custom_forward(self):
    """Return a closure over ``_call_forward`` suitable for torch checkpointing."""
    def _wrapped(*inputs):
        return self._call_forward(*inputs)
    return _wrapped
def forward(self, x):
    """Run the wrapped module and add the (optionally normalized) LoRA delta."""
    base_out = self.org_forward(x)
    # NOTE: routing the LoRA path through torch.utils.checkpoint loses the
    # grad for unknown reasons, so it is always computed directly here.
    lora_out = self._call_forward(x)
    if self.is_normalizing:
        # Scale the LoRA contribution so the combined activation peak does
        # not exceed the original module's peak activation magnitude.
        eps = 1e-6  # small constant to avoid division by zero
        base_peak = torch.abs(base_out).max()
        # how much the LoRA output could raise the peak magnitude
        peak_increase = (torch.abs(base_out + lora_out) - torch.abs(base_out)).max()
        scaler = base_peak / (base_peak + peak_increase + eps)
        # stash the scaler so apply_stored_normalizer() can later fold it
        # into the up/down weights
        self.normalize_scaler = scaler.clone().detach()
        lora_out = lora_out * scaler
    return base_out + lora_out
def enable_gradient_checkpointing(self):
@@ -180,11 +188,39 @@ class LoRAModule(torch.nn.Module):
def disable_gradient_checkpointing(self):
self.is_checkpointing = False
@torch.no_grad()
def apply_stored_normalizer(self, target_normalize_scaler: float = 1.0):
    """
    Fold the stored normalization scaler into this module's up/down weights.

    This must be called before saving or the normalization will be lost.
    It is probably best to call after each batch as well.

    :param target_normalize_scaler: value the stored scaler is reset to
        after it has been applied to the weights.
    """
    # get state dict
    state_dict = self.state_dict()
    dtype = state_dict['lora_up.weight'].dtype
    device = state_dict['lora_up.weight'].device
    # todo should we do this at fp32?
    total_module_scale = torch.tensor(self.normalize_scaler / target_normalize_scaler) \
        .to(device, dtype=dtype)
    num_modules_layers = 2  # up and down
    # split the total scale evenly across the two layers so their product
    # equals the total scale
    up_down_scale = torch.pow(total_module_scale, 1.0 / num_modules_layers) \
        .to(device, dtype=dtype)
    # apply the scaler to the up and down weights
    for key in state_dict.keys():
        # BUGFIX: this module's own state_dict keys are 'lora_up.weight' /
        # 'lora_down.weight' with no leading dot (see the direct lookup
        # above), so the previous endswith('.lora_up.weight') check never
        # matched and the weights were never rescaled. Matching without the
        # dot still covers any prefixed keys as well.
        if key.endswith('lora_up.weight') or key.endswith('lora_down.weight'):
            # scale in place so the module's params are updated
            state_dict[key] *= up_down_scale
    # reset the normalization scaler
    self.normalize_scaler = target_normalize_scaler
class LoRASpecialNetwork(LoRANetwork):
_multiplier: float = 1.0
is_active: bool = False
NUM_OF_BLOCKS = 12 # フルモデル相当でのup,downの層の数
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
@@ -230,7 +266,6 @@ class LoRASpecialNetwork(LoRANetwork):
"""
# call the parent of the parent we are replacing (LoRANetwork) init
super(LoRANetwork, self).__init__()
self.multiplier = multiplier
self.lora_dim = lora_dim
self.alpha = alpha
@@ -240,6 +275,11 @@ class LoRASpecialNetwork(LoRANetwork):
self.rank_dropout = rank_dropout
self.module_dropout = module_dropout
self.is_checkpointing = False
self._multiplier: float = 1.0
self.is_active: bool = False
self._is_normalizing: bool = False
# triggers the state updates
self.multiplier = multiplier
if modules_dim is not None:
print(f"create LoRA network from weights")
@@ -451,21 +491,20 @@ class LoRASpecialNetwork(LoRANetwork):
for lora in loras:
lora.to(device, dtype)
def get_all_modules(self):
    """Collect every LoRA module on this network (unet first, then text encoder)."""
    all_loras = []
    for attr_name in ('unet_loras', 'text_encoder_loras'):
        all_loras.extend(getattr(self, attr_name, []))
    return all_loras
def _update_checkpointing(self):
    """Sync every LoRA module's gradient-checkpointing state with this network's flag."""
    # NOTE(review): the diff residue here concatenated the old per-list
    # implementation (walking unet_loras / text_encoder_loras separately)
    # with the new get_all_modules() loop, toggling each module twice.
    # A single pass over all modules is sufficient and equivalent.
    for module in self.get_all_modules():
        if self.is_checkpointing:
            module.enable_gradient_checkpointing()
        else:
            module.disable_gradient_checkpointing()
def enable_gradient_checkpointing(self):
# not supported
@@ -476,3 +515,17 @@ class LoRASpecialNetwork(LoRANetwork):
# not supported
self.is_checkpointing = False
self._update_checkpointing()
@property
def is_normalizing(self) -> bool:
    # Whether the child LoRA modules compute and store a normalization
    # scaler during their forward passes.
    return self._is_normalizing

@is_normalizing.setter
def is_normalizing(self, value: bool):
    # Propagate the flag to every child LoRA module so their forward()
    # paths pick the change up immediately.
    self._is_normalizing = value
    for module in self.get_all_modules():
        module.is_normalizing = self._is_normalizing
def apply_stored_normalizer(self, target_normalize_scaler: float = 1.0):
    """Fold each child module's stored normalization scaler into its weights."""
    modules = self.get_all_modules()
    for lora_module in modules:
        lora_module.apply_stored_normalizer(target_normalize_scaler)

View File

@@ -38,10 +38,13 @@ SD_PREFIX_TEXT_ENCODER2 = "te2"
class BlankNetwork:
multiplier = 1.0
is_active = True
def __init__(self):
self.multiplier = 1.0
self.is_active = True
self.is_normalizing = False
def apply_stored_normalizer(self, target_normalize_scaler: float = 1.0):
    # No-op: BlankNetwork has no LoRA weights to normalize; it only mirrors
    # the real network's interface so callers need no special-casing.
    pass
def __enter__(self):
@@ -258,6 +261,12 @@ class StableDiffusion:
else:
network = BlankNetwork()
was_network_normalizing = network.is_normalizing
# apply the normalizer if it is normalizing before inference and disable it
if network.is_normalizing:
network.apply_stored_normalizer()
network.is_normalizing = False
# save current seed state for training
rng_state = torch.get_rng_state()
cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
@@ -377,6 +386,7 @@ class StableDiffusion:
if self.network is not None:
self.network.train()
self.network.multiplier = start_multiplier
self.network.is_normalizing = was_network_normalizing
# self.tokenizer.to(original_device_dict['tokenizer'])
def get_latent_noise(