Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-01-26 16:39:47 +00:00
Added gradient accumulation finally
@@ -448,16 +448,17 @@ class SDTrainer(BaseSDTrainProcess):
         # I spent weeks on fighting this. DON'T DO IT
         # with fsdp_overlap_step_with_backward():
         loss.backward()
 
-        torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm)
-        # flush()
-
-        with self.timer('optimizer_step'):
-            # apply gradients
-            self.optimizer.step()
-            self.optimizer.zero_grad(set_to_none=True)
-        with self.timer('scheduler_step'):
-            self.lr_scheduler.step()
+        if not self.is_grad_accumulation_step:
+            torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm)
+            # only step if we are not accumulating
+            with self.timer('optimizer_step'):
+                # apply gradients
+                self.optimizer.step()
+                self.optimizer.zero_grad(set_to_none=True)
+            with self.timer('scheduler_step'):
+                self.lr_scheduler.step()
 
         if self.embedding is not None:
             with self.timer('restore_embeddings'):
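The hunk above is the core of the feature: loss.backward() still runs for every micro-batch, but gradient clipping, the optimizer step, the scheduler step, and zero_grad are now gated behind the accumulation flag. A minimal standalone sketch of the same pattern in plain PyTorch (names other than the flag are illustrative; note the commit does not rescale the loss, though dividing it by the accumulation count is a common companion so update magnitudes stay comparable):

    import torch

    def train_with_accumulation(model, optimizer, batches,
                                accumulation_steps=4, max_grad_norm=1.0):
        # hypothetical loop mirroring the gating introduced above
        micro_step = 0
        for inputs, targets in batches:
            loss = torch.nn.functional.mse_loss(model(inputs), targets)
            loss.backward()  # gradients keep summing into .grad
            micro_step += 1
            is_grad_accumulation_step = micro_step % accumulation_steps != 0
            if not is_grad_accumulation_step:
                # window is full: clip once over the summed gradients, then step
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)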
@@ -60,6 +60,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
         self.custom_pipeline = custom_pipeline
         self.step_num = 0
         self.start_step = 0
+        self.epoch_num = 0
+        # start at 1 so we can do a sample at the start
+        self.grad_accumulation_step = 1
+        # if true, then we do not do an optimizer step. We are accumulating gradients
+        self.is_grad_accumulation_step = False
         self.device = self.get_conf('device', self.job.device)
         self.device_torch = torch.device(self.device)
         network_config = self.get_conf('network', None)
@@ -250,7 +255,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
 
     def get_training_info(self):
         info = OrderedDict({
-            'step': self.step_num + 1
+            'step': self.step_num,
+            'epoch': self.epoch_num,
         })
         return info
 
@@ -417,6 +423,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
             # if 'training_info' in Orderdict keys
             if 'training_info' in meta and 'step' in meta['training_info'] and self.train_config.start_step is None:
                 self.step_num = meta['training_info']['step']
+                if 'epoch' in meta['training_info']:
+                    self.epoch_num = meta['training_info']['epoch']
                 self.start_step = self.step_num
                 print(f"Found step {self.step_num} in metadata, starting from there")
 
@@ -441,6 +449,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
             # if 'training_info' in Orderdict keys
             if 'training_info' in meta and 'step' in meta['training_info']:
                 self.step_num = meta['training_info']['step']
+                if 'epoch' in meta['training_info']:
+                    self.epoch_num = meta['training_info']['epoch']
                 self.start_step = self.step_num
                 print(f"Found step {self.step_num} in metadata, starting from there")
 
@@ -712,6 +722,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
             # if 'training_info' in Orderdict keys
             if 'training_info' in meta and 'step' in meta['training_info']:
                 self.step_num = meta['training_info']['step']
+                if 'epoch' in meta['training_info']:
+                    self.epoch_num = meta['training_info']['epoch']
                 self.start_step = self.step_num
                 print(f"Found step {self.step_num} in metadata, starting from there")
 
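All three checkpoint-loading paths now restore the epoch alongside the step, and together with the get_training_info fix above (step_num is saved as-is, no more +1) the counter makes a clean round trip through checkpoint metadata. Roughly (a standalone sketch; the metadata keys match the diff, the helper names are illustrative):

    from collections import OrderedDict

    def get_training_info(step_num, epoch_num):
        # what gets written into checkpoint metadata (same keys as the diff)
        return OrderedDict({'step': step_num, 'epoch': epoch_num})

    def restore_training_info(meta):
        # mirrors the load sites above; missing keys leave the counters at 0
        step_num, epoch_num = 0, 0
        if 'training_info' in meta and 'step' in meta['training_info']:
            step_num = meta['training_info']['step']
            if 'epoch' in meta['training_info']:
                epoch_num = meta['training_info']['epoch']
        return step_num, epoch_num

    meta = {'training_info': get_training_info(step_num=500, epoch_num=3)}
    assert restore_training_info(meta) == (500, 3)  # resume where training stopped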
@@ -1035,7 +1047,16 @@ class BaseSDTrainProcess(BaseTrainProcess):
         self.sd.set_device_state(self.train_device_state_preset)
         flush()
-        # self.step_num = 0
-        for step in range(self.step_num, self.train_config.steps):
+
+        ###################################################################
+        # TRAIN LOOP
+        ###################################################################
+
+        start_step_num = self.step_num
+        for step in range(start_step_num, self.train_config.steps):
+            self.step_num = step
+            # default to true so various things can turn it off
+            self.is_grad_accumulation_step = True
             if self.train_config.free_u:
                 self.sd.pipeline.enable_freeu(s1=0.9, s2=0.2, b1=1.1, b2=1.2)
             self.progress_bar.unpause()
@@ -1071,12 +1092,31 @@ class BaseSDTrainProcess(BaseTrainProcess):
                     self.progress_bar.pause()
                     dataloader_iterator = iter(dataloader)
                     trigger_dataloader_setup_epoch(dataloader)
+                    self.epoch_num += 1
+                    if self.train_config.gradient_accumulation_steps == -1:
+                        # if we are accumulating for an entire epoch, trigger a step
+                        self.is_grad_accumulation_step = False
+                        self.grad_accumulation_step = 0
                     with self.timer('get_batch'):
                         batch = next(dataloader_iterator)
                     self.progress_bar.unpause()
             else:
                 batch = None
 
+            # setup accumulation
+            if self.train_config.gradient_accumulation_steps == -1:
+                # epoch is handling the accumulation, dont touch it
+                pass
+            else:
+                # determine if we are accumulating or not
+                # since optimizer step happens in the loop, we trigger it a step early
+                # since we cannot reprocess it before them
+                optimizer_step_at = self.train_config.gradient_accumulation_steps
+                is_optimizer_step = self.grad_accumulation_step >= optimizer_step_at
+                self.is_grad_accumulation_step = not is_optimizer_step
+                if is_optimizer_step:
+                    self.grad_accumulation_step = 0
+
             # flush()
             ### HOOK ###
             self.timer.start('train_loop')
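The trigger is computed a step early on purpose: the gated optimizer step for the current batch runs later in the same iteration, so the decision has to come from the count of micro-batches already accumulated. A toy trace of the two counters (pure Python, no training; 4 accumulation steps is just an example value):

    def trace_optimizer_steps(total_steps=10, gradient_accumulation_steps=4):
        # replays the counter logic above; returns steps where optimizer.step() fires
        grad_accumulation_step = 1  # the constructor starts it at 1
        fired = []
        for step in range(total_steps):
            is_optimizer_step = grad_accumulation_step >= gradient_accumulation_steps
            if is_optimizer_step:
                grad_accumulation_step = 0
                fired.append(step)
            # ...backward() runs here; step/zero_grad only when is_optimizer_step...
            grad_accumulation_step += 1  # end-of-step bookkeeping from the next hunk
        return fired

    print(trace_optimizer_steps())  # [3, 7]: one update per 4 micro-batches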
@@ -1144,17 +1184,24 @@ class BaseSDTrainProcess(BaseTrainProcess):
 
             # sets progress bar to match out step
             self.progress_bar.update(step - self.progress_bar.n)
-            # end of step
-            self.step_num = step
 
             # flush every 10 steps
             # if self.step_num % 10 == 0:
             #     flush()
+            #############################
+            # End of step
+            #############################
+
+            # update various steps
+            self.step_num = step + 1
+            self.grad_accumulation_step += 1
+
+        ###################################################################
+        ## END TRAIN LOOP
+        ###################################################################
 
         self.progress_bar.close()
         if self.train_config.free_u:
             self.sd.pipeline.disable_freeu()
-        self.sample(self.step_num + 1)
+        self.sample(self.step_num)
         print("")
         self.save()
@@ -52,6 +52,7 @@ class LormModuleSettingsConfig:
 class LoRMConfig:
     def __init__(self, **kwargs):
         self.extract_mode: str = kwargs.get('extract_mode', 'ratio')
+        self.do_conv: bool = kwargs.get('do_conv', False)
         self.extract_mode_param: dict = kwargs.get('extract_mode_param', 0.25)
         self.parameter_threshold: int = kwargs.get('parameter_threshold', 0)
         module_settings = kwargs.get('module_settings', [])
@@ -110,6 +111,8 @@ class NetworkConfig:
             # set linear to arbitrary values so it makes them
             self.linear = 4
             self.rank = 4
+            if self.lorm_config.do_conv:
+                self.conv = 4
 
 
 AdapterTypes = Literal['t2i', 'ip', 'ip+']
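With the new do_conv flag a LoRM config can opt conv layers into extraction, and NetworkConfig then seeds a placeholder conv rank next to the linear one. A hypothetical kwargs dict as LoRMConfig would consume it (key names from the diff, values illustrative):

    lorm_kwargs = {
        'extract_mode': 'ratio',      # default
        'extract_mode_param': 0.25,   # keep roughly a quarter of each layer's rank
        'do_conv': True,              # new: extract conv layers too
        'parameter_threshold': 0,
    }
    # NetworkConfig then sets self.linear = self.rank = 4 and, because do_conv
    # is set, self.conv = 4; these are placeholders that the per-layer
    # extraction mode overrides later.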
@@ -177,6 +180,10 @@ class TrainConfig:
         self.noise_multiplier = kwargs.get('noise_multiplier', 1.0)
         self.img_multiplier = kwargs.get('img_multiplier', 1.0)
 
+        # set to -1 to accumulate gradients for entire epoch
+        # warning, only do this with a small dataset or you will run out of memory
+        self.gradient_accumulation_steps = kwargs.get('gradient_accumulation_steps', 1)
+
         # short long captions will double your batch size. This only works when a dataset is
         # prepared with a json caption file that has both short and long captions in it. It will
         # Double up every image and run it through with both short and long captions. The idea
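Accumulation is therefore off by default (1), any N > 1 folds N micro-batches into each optimizer step, and -1 defers the step to the epoch boundary. A hypothetical excerpt of the kwargs TrainConfig would receive (only gradient_accumulation_steps comes from this diff, the other keys are illustrative):

    train_kwargs = {
        'batch_size': 1,
        # effective batch becomes batch_size * 4: one optimizer step per 4 micro-batches
        'gradient_accumulation_steps': 4,
        # or -1 to accumulate across a whole epoch (see the warning above)
    }
    config = TrainConfig(**train_kwargs)  # omitted keys fall back to kwargs.get defaults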
@@ -191,6 +191,8 @@ def extract_conv(
     if lora_rank >= out_ch / 2:
+        lora_rank = int(out_ch / 2)
+        print(f"rank is higher than it should be")
         # print(f"Skipping layer as determined rank is too high")
         # return None, None, None, None
-        return weight, 'full'
+        # return weight, 'full'
 
     U = U[:, :lora_rank]
@@ -243,6 +245,8 @@ def extract_linear(
     if lora_rank >= out_ch / 2:
+        # print(f"rank is higher than it should be")
+        lora_rank = int(out_ch / 2)
-        return weight, 'full'
+        # return weight, 'full'
         # print(f"Skipping layer as determined rank is too high")
         # return None, None, None, None
 
     U = U[:, :lora_rank]
     S = S[:lora_rank]
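Both extractors split a weight matrix with SVD and keep only the top singular directions, so a rank at or above half the output channels would make the two factors hold more parameters than the original layer; the new lines clamp the rank instead of bailing out. A compact sketch of the linear case (standalone; the clamp threshold mirrors the diff, everything else is a generic SVD split, not the library's exact code):

    import torch

    def extract_linear_lowrank(weight: torch.Tensor, lora_rank: int):
        # weight: (out_ch, in_ch); returns (up, down) with up @ down ~ weight
        out_ch, in_ch = weight.shape
        if lora_rank >= out_ch / 2:
            # beyond this point the factors cost more than the full matrix
            lora_rank = int(out_ch / 2)
        U, S, Vh = torch.linalg.svd(weight.float(), full_matrices=False)
        down = Vh[:lora_rank, :]               # (rank, in_ch)
        up = U[:, :lora_rank] * S[:lora_rank]  # fold singular values into up
        return up, down

    w = torch.randn(64, 128)
    up, down = extract_linear_lowrank(w, lora_rank=16)
    print((up @ down - w).norm() / w.norm())  # relative error of the rank-16 fit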
@@ -358,6 +362,8 @@ def convert_diffusers_unet_to_lorm(
                     mode_param=extract_mode_param,
                     device=child_module.weight.device,
                 )
+                if down_weight is None:
+                    continue
                 down_weight = down_weight.to(dtype=dtype)
                 up_weight = up_weight.to(dtype=dtype)
                 bias_weight = None
@@ -398,6 +404,8 @@ def convert_diffusers_unet_to_lorm(
                     mode_param=extract_mode_param,
                     device=child_module.weight.device,
                 )
+                if down_weight is None:
+                    continue
                 down_weight = down_weight.to(dtype=dtype)
                 up_weight = up_weight.to(dtype=dtype)
                 bias_weight = None