mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-13 14:39:50 +00:00
Fixed an issue with the bucket dataloader cropping in too much. Added normalization capabilities to LoRA modules. Still testing the effects, but it should prevent LoRAs from burning and also make them more compatible with stacking many LoRAs.
@@ -208,7 +208,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
        if self.network is not None:
            prev_multiplier = self.network.multiplier
            self.network.multiplier = 1.0
            # TODO handle dreambooth, fine tuning, etc
            if self.network_config.normalize:
                # apply the normalization
                self.network.apply_stored_normalizer()
            self.network.save_weights(
                file_path,
                dtype=get_torch_dtype(self.save_config.dtype),
@@ -323,7 +325,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
        imgs = imgs.to(self.device_torch, dtype=dtype)
        latents = self.sd.encode_images(imgs)

        self.sd.noise_scheduler.set_timesteps(
            self.train_config.max_denoising_steps, device=self.device_torch
        )
@@ -429,6 +430,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
        if self.train_config.gradient_checkpointing:
            self.network.enable_gradient_checkpointing()

        # set the network to normalize if we are using normalization
        self.network.is_normalizing = self.network_config.normalize

        latest_save_path = self.get_latest_save_path()
        if latest_save_path is not None:
            self.print(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####")
@@ -522,71 +526,84 @@ class BaseSDTrainProcess(BaseTrainProcess):
        dataloader_reg = None
        dataloader_iterator_reg = None

        # zero any gradients
        optimizer.zero_grad()

        # self.step_num = 0
        for step in range(self.step_num, self.train_config.steps):
            with torch.no_grad():
                # if it is an even step and we have a reg dataset, use that
                # todo improve this logic to send one of each through if we can; buckets and batch size might be an issue
                if step % 2 == 0 and dataloader_reg is not None:
                    try:
                        batch = next(dataloader_iterator_reg)
                    except StopIteration:
                        # hit the end of an epoch, reset
                        dataloader_iterator_reg = iter(dataloader_reg)
                        batch = next(dataloader_iterator_reg)
                elif dataloader is not None:
                    try:
                        batch = next(dataloader_iterator)
                    except StopIteration:
                        # hit the end of an epoch, reset
                        dataloader_iterator = iter(dataloader)
                        batch = next(dataloader_iterator)
                else:
                    batch = None

                # turn on normalization if we are using it and it is not on
                if self.network is not None and self.network_config.normalize and not self.network.is_normalizing:
                    self.network.is_normalizing = True

            ### HOOK ###
            loss_dict = self.hook_train_loop(batch)
            flush()

            with torch.no_grad():
                if self.train_config.optimizer.lower().startswith('dadaptation') or \
                        self.train_config.optimizer.lower().startswith('prodigy'):
                    learning_rate = (
                        optimizer.param_groups[0]["d"] *
                        optimizer.param_groups[0]["lr"]
                    )
                else:
                    learning_rate = optimizer.param_groups[0]['lr']

                prog_bar_string = f"lr: {learning_rate:.1e}"
                for key, value in loss_dict.items():
                    prog_bar_string += f" {key}: {value:.3e}"

                self.progress_bar.set_postfix_str(prog_bar_string)

                # don't do on first step
                if self.step_num != self.start_step:
                    # pause progress bar
                    self.progress_bar.unpause()  # makes it so it doesn't track time
                    if self.sample_config.sample_every and self.step_num % self.sample_config.sample_every == 0:
                        # print above the progress bar
                        self.sample(self.step_num)

                    if self.save_config.save_every and self.step_num % self.save_config.save_every == 0:
                        # print above the progress bar
                        self.print(f"Saving at step {self.step_num}")
                        self.save(self.step_num)

                    if self.logging_config.log_every and self.step_num % self.logging_config.log_every == 0:
                        # log to tensorboard
                        if self.writer is not None:
                            for key, value in loss_dict.items():
                                self.writer.add_scalar(f"{key}", value, self.step_num)
                            self.writer.add_scalar(f"lr", learning_rate, self.step_num)
                        self.progress_bar.refresh()

                # sets the progress bar to match our step
                self.progress_bar.update(step - self.progress_bar.n)
                # end of step
                self.step_num = step

            # apply network normalizer if we are using it
            if self.network is not None and self.network.is_normalizing:
                self.network.apply_stored_normalizer()

        self.sample(self.step_num + 1)
        print("")

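As a side note on the batch selection at the top of the loop above, a toy sketch of the even-step regularization interleave (plain lists stand in for the dataloaders; all names here are illustrative, not from the codebase):

def next_batch(iterator, dataset):
    # same StopIteration handling as above: restart the epoch when the iterator is exhausted
    try:
        return next(iterator), iterator
    except StopIteration:
        iterator = iter(dataset)
        return next(iterator), iterator

reg_items = ['reg_a', 'reg_b']
train_items = ['img_1', 'img_2', 'img_3']
reg_iter, train_iter = iter(reg_items), iter(train_items)

for step in range(6):
    if step % 2 == 0 and reg_items:
        batch, reg_iter = next_batch(reg_iter, reg_items)
    else:
        batch, train_iter = next_batch(train_iter, train_items)
    print(step, batch)
# even steps draw from the regularization set, odd steps from the training set
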
@@ -48,6 +48,7 @@ class NetworkConfig:
        self.alpha: float = kwargs.get('alpha', 1.0)
        self.linear_alpha: float = kwargs.get('linear_alpha', self.alpha)
        self.conv_alpha: float = kwargs.get('conv_alpha', self.conv)
        self.normalize = kwargs.get('normalize', False)


class EmbeddingConfig:

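For reference, a minimal sketch of turning the new flag on (hypothetical usage: NetworkConfig reads its options from keyword arguments, and only alpha and normalize appear in this diff, so any other options are assumed):

network_config = NetworkConfig(alpha=1.0, normalize=True)
assert network_config.normalize is True  # defaults to False when the key is not set
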
@@ -95,34 +95,40 @@ class BucketsMixin:
            # the other dimension should be the same ratio it is now (bigger)
            new_width = resolution
            new_height = resolution
            if width > height:
                # scale width to match new resolution
                new_width = int(width * (resolution / height))
                file_item.crop_width = new_width
                file_item.scale_to_width = new_width
                file_item.crop_height = resolution
                file_item.scale_to_height = resolution
                # make sure new_width is divisible by bucket_tolerance
                if new_width % bucket_tolerance != 0:
                    # reduce it to the nearest divisible number
                    reduction = new_width % bucket_tolerance
                    file_item.crop_width = new_width - reduction
                    # adjust the new x position so we evenly crop
                    file_item.crop_x = int(file_item.crop_x + (reduction / 2))
            elif height > width:
                # scale height to match new resolution
                new_height = int(height * (resolution / width))
                file_item.crop_height = new_height
                file_item.scale_to_height = new_height
                file_item.scale_to_width = resolution
                file_item.crop_width = resolution
                # make sure new_height is divisible by bucket_tolerance
                if new_height % bucket_tolerance != 0:
                    # reduce it to the nearest divisible number
                    reduction = new_height % bucket_tolerance
                    file_item.crop_height = new_height - reduction
                    # adjust the new y position so we evenly crop
                    file_item.crop_y = int(file_item.crop_y + (reduction / 2))
            else:
                # square image
                file_item.crop_height = resolution
                file_item.scale_to_height = resolution
                file_item.scale_to_width = resolution
                file_item.crop_width = resolution

            # check if bucket exists, if not, create it
            bucket_key = f'{new_width}x{new_height}'

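For intuition, a standalone sketch of the scale-then-crop math the branches above implement (the function and the numbers are illustrative, not part of the codebase; the real code adjusts an existing crop_x/crop_y rather than computing offsets from zero):

def fit_to_bucket(width, height, resolution, bucket_tolerance):
    # scale the short side to `resolution`, keep the aspect ratio on the long side,
    # then round the long side down to a multiple of `bucket_tolerance`,
    # shifting the crop by half of the reduction so it stays centered
    if width > height:
        scale_w, scale_h = int(width * (resolution / height)), resolution
    elif height > width:
        scale_w, scale_h = resolution, int(height * (resolution / width))
    else:
        scale_w, scale_h = resolution, resolution
    crop_w = scale_w - (scale_w % bucket_tolerance)
    crop_h = scale_h - (scale_h % bucket_tolerance)
    crop_x = (scale_w - crop_w) // 2
    crop_y = (scale_h - crop_h) // 2
    return (scale_w, scale_h), (crop_x, crop_y, crop_w, crop_h)

# a 1250x800 image at resolution 512 with tolerance 64:
# scaled to 800x512, cropped to 768x512 with the crop shifted right by 16px
print(fit_to_bucket(1250, 800, 512, 64))
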
@@ -8,6 +8,7 @@ import torch
from transformers import CLIPTextModel

from .paths import SD_SCRIPTS_ROOT
from .train_tools import get_torch_dtype

sys.path.append(SD_SCRIPTS_ROOT)

@@ -78,6 +79,8 @@ class LoRAModule(torch.nn.Module):
        self.rank_dropout = rank_dropout
        self.module_dropout = module_dropout
        self.is_checkpointing = False
        self.is_normalizing = False
        self.normalize_scaler = 1.0

    def apply_to(self):
        self.org_forward = self.org_module.forward
@@ -91,8 +94,8 @@ class LoRAModule(torch.nn.Module):
        batch_size = lora_up.size(0)
        # batch will have all negative prompts first and positive prompts second
        # our multiplier list is for a prompt pair. So we need to repeat it for positive and negative prompts
        # if there is more than our multiplier, it is likely a batch size increase, so we need to
        # interleave the multipliers
        if isinstance(self.multiplier, list):
            if len(self.multiplier) == 0:
                # single item, just return it
@@ -153,25 +156,30 @@ class LoRAModule(torch.nn.Module):
        return lx * multiplier * scale

    def create_custom_forward(self):
        def custom_forward(*inputs):
            return self._call_forward(*inputs)

        return custom_forward

    def forward(self, x):
        org_forwarded = self.org_forward(x)
        # TODO this just loses the grad. Not sure why. Probably why no one else is doing it either
        # if torch.is_grad_enabled() and self.is_checkpointing and self.training:
        #     lora_output = checkpoint(
        #         self.create_custom_forward(),
        #         x,
        #     )
        # else:
        #     lora_output = self._call_forward(x)

        lora_output = self._call_forward(x)

        if self.is_normalizing:
            # get a dim array from orig forward that has the index of all dimensions except the batch and channel

            # Calculate the target magnitude for the combined output
            orig_max = torch.max(torch.abs(org_forwarded))

            # Calculate the additional increase in magnitude that lora_output would introduce
            potential_max_increase = torch.max(torch.abs(org_forwarded + lora_output) - torch.abs(org_forwarded))

            epsilon = 1e-6  # Small constant to avoid division by zero

            # Calculate the scaling factor for the lora_output
            # to ensure that the potential increase in magnitude doesn't change the original max
            normalize_scaler = orig_max / (orig_max + potential_max_increase + epsilon)

            # save the scaler so it can be applied later
            self.normalize_scaler = normalize_scaler.clone().detach()

            lora_output *= normalize_scaler

        return org_forwarded + lora_output

    def enable_gradient_checkpointing(self):
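As a rough numeric illustration of the scaler computed above (a standalone sketch, not part of the commit; the tensor values are made up):

import torch

org_forwarded = torch.tensor([2.0, -4.0, 1.0])  # pretend output of the original layer
lora_output = torch.tensor([0.5, -1.0, 0.25])   # pretend LoRA contribution

orig_max = torch.max(torch.abs(org_forwarded))  # 4.0
potential_max_increase = torch.max(
    torch.abs(org_forwarded + lora_output) - torch.abs(org_forwarded))  # 1.0
normalize_scaler = orig_max / (orig_max + potential_max_increase + 1e-6)  # ~0.8

# the LoRA contribution gets scaled down by ~0.8 before being added back in
print(org_forwarded + lora_output * normalize_scaler)  # tensor([ 2.4000, -4.8000,  1.2000])
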
@@ -180,11 +188,39 @@ class LoRAModule(torch.nn.Module):
    def disable_gradient_checkpointing(self):
        self.is_checkpointing = False

    @torch.no_grad()
    def apply_stored_normalizer(self, target_normalize_scaler: float = 1.0):
        """
        Applies the previous normalization calculation to the module.
        This must be called before saving or normalization will be lost.
        It is probably best to call it after each batch as well.
        We just scale the up/down weights to match this scaler.
        :return:
        """
        # get state dict
        state_dict = self.state_dict()
        dtype = state_dict['lora_up.weight'].dtype
        device = state_dict['lora_up.weight'].device

        # todo should we do this at fp32?

        total_module_scale = torch.tensor(self.normalize_scaler / target_normalize_scaler) \
            .to(device, dtype=dtype)
        num_modules_layers = 2  # up and down
        up_down_scale = torch.pow(total_module_scale, 1.0 / num_modules_layers) \
            .to(device, dtype=dtype)

        # apply the scaler to the up and down weights
        for key in state_dict.keys():
            if key.endswith('.lora_up.weight') or key.endswith('.lora_down.weight'):
                # do it in place so the params are updated
                state_dict[key] *= up_down_scale

        # reset the normalization scaler
        self.normalize_scaler = target_normalize_scaler


class LoRASpecialNetwork(LoRANetwork):
    _multiplier: float = 1.0
    is_active: bool = False

    NUM_OF_BLOCKS = 12  # number of up/down layers in the full-model equivalent

    UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
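A quick standalone check of why the scaler is split as a square root across the two weights in apply_stored_normalizer (illustrative tensors only): the effective LoRA delta is lora_up @ lora_down, so scaling both factors by sqrt(s) scales the product by exactly s.

import torch

total_scale = torch.tensor(0.25)  # e.g. normalize_scaler / target_normalize_scaler
up = torch.randn(8, 4)
down = torch.randn(4, 8)

per_layer = torch.pow(total_scale, 1.0 / 2)  # one factor of 0.5 for each of up and down
scaled_delta = (up * per_layer) @ (down * per_layer)

assert torch.allclose(scaled_delta, total_scale * (up @ down), atol=1e-6)
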
@@ -230,7 +266,6 @@ class LoRASpecialNetwork(LoRANetwork):
        """
        # call the parent of the parent we are replacing (LoRANetwork) init
        super(LoRANetwork, self).__init__()
        self.multiplier = multiplier

        self.lora_dim = lora_dim
        self.alpha = alpha
@@ -240,6 +275,11 @@ class LoRASpecialNetwork(LoRANetwork):
        self.rank_dropout = rank_dropout
        self.module_dropout = module_dropout
        self.is_checkpointing = False
        self._multiplier: float = 1.0
        self.is_active: bool = False
        self._is_normalizing: bool = False
        # triggers the state updates
        self.multiplier = multiplier

        if modules_dim is not None:
            print(f"create LoRA network from weights")
@@ -451,21 +491,20 @@ class LoRASpecialNetwork(LoRANetwork):
        for lora in loras:
            lora.to(device, dtype)

    def get_all_modules(self):
        loras = []
        if hasattr(self, 'unet_loras'):
            loras += self.unet_loras
        if hasattr(self, 'text_encoder_loras'):
            loras += self.text_encoder_loras
        return loras

    def _update_checkpointing(self):
        for module in self.get_all_modules():
            if self.is_checkpointing:
                module.enable_gradient_checkpointing()
            else:
                module.disable_gradient_checkpointing()

    def enable_gradient_checkpointing(self):
        # not supported
@@ -476,3 +515,17 @@ class LoRASpecialNetwork(LoRANetwork):
        # not supported
        self.is_checkpointing = False
        self._update_checkpointing()

    @property
    def is_normalizing(self) -> bool:
        return self._is_normalizing

    @is_normalizing.setter
    def is_normalizing(self, value: bool):
        self._is_normalizing = value
        for module in self.get_all_modules():
            module.is_normalizing = self._is_normalizing

    def apply_stored_normalizer(self, target_normalize_scaler: float = 1.0):
        for module in self.get_all_modules():
            module.apply_stored_normalizer(target_normalize_scaler)

@@ -38,10 +38,13 @@ SD_PREFIX_TEXT_ENCODER2 = "te2"


class BlankNetwork:
    multiplier = 1.0
    is_active = True

    def __init__(self):
        self.multiplier = 1.0
        self.is_active = True
        self.is_normalizing = False

    def apply_stored_normalizer(self, target_normalize_scaler: float = 1.0):
        pass

    def __enter__(self):
@@ -258,6 +261,12 @@ class StableDiffusion:
        else:
            network = BlankNetwork()

        was_network_normalizing = network.is_normalizing
        # if the network is normalizing, apply the stored normalizer before inference and disable it
        if network.is_normalizing:
            network.apply_stored_normalizer()
            network.is_normalizing = False

        # save current seed state for training
        rng_state = torch.get_rng_state()
        cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
@@ -377,6 +386,7 @@ class StableDiffusion:
        if self.network is not None:
            self.network.train()
            self.network.multiplier = start_multiplier
            self.network.is_normalizing = was_network_normalizing
        # self.tokenizer.to(original_device_dict['tokenizer'])

    def get_latent_noise(