Added gradient accumulation finally

This commit is contained in:
Jaret Burkett
2023-10-28 13:14:29 -06:00
parent 6f3e0d5af2
commit 298001439a
4 changed files with 79 additions and 16 deletions

View File

@@ -448,16 +448,17 @@ class SDTrainer(BaseSDTrainProcess):
# I spent weeks on fighting this. DON'T DO IT
# with fsdp_overlap_step_with_backward():
loss.backward()
torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm)
# flush()
with self.timer('optimizer_step'):
# apply gradients
self.optimizer.step()
self.optimizer.zero_grad(set_to_none=True)
with self.timer('scheduler_step'):
self.lr_scheduler.step()
if not self.is_grad_accumulation_step:
torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm)
# only step if we are not accumulating
with self.timer('optimizer_step'):
# apply gradients
self.optimizer.step()
self.optimizer.zero_grad(set_to_none=True)
with self.timer('scheduler_step'):
self.lr_scheduler.step()
if self.embedding is not None:
with self.timer('restore_embeddings'):

View File

@@ -60,6 +60,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.custom_pipeline = custom_pipeline
self.step_num = 0
self.start_step = 0
self.epoch_num = 0
# start at 1 so we can do a sample at the start
self.grad_accumulation_step = 1
# if true, then we do not do an optimizer step. We are accumulating gradients
self.is_grad_accumulation_step = False
self.device = self.get_conf('device', self.job.device)
self.device_torch = torch.device(self.device)
network_config = self.get_conf('network', None)
@@ -250,7 +255,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
def get_training_info(self):
info = OrderedDict({
'step': self.step_num + 1
'step': self.step_num,
'epoch': self.epoch_num,
})
return info
@@ -417,6 +423,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
# if 'training_info' in Orderdict keys
if 'training_info' in meta and 'step' in meta['training_info'] and self.train_config.start_step is None:
self.step_num = meta['training_info']['step']
if 'epoch' in meta['training_info']:
self.epoch_num = meta['training_info']['epoch']
self.start_step = self.step_num
print(f"Found step {self.step_num} in metadata, starting from there")
@@ -441,6 +449,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
# if 'training_info' in Orderdict keys
if 'training_info' in meta and 'step' in meta['training_info']:
self.step_num = meta['training_info']['step']
if 'epoch' in meta['training_info']:
self.epoch_num = meta['training_info']['epoch']
self.start_step = self.step_num
print(f"Found step {self.step_num} in metadata, starting from there")
@@ -712,6 +722,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
# if 'training_info' in Orderdict keys
if 'training_info' in meta and 'step' in meta['training_info']:
self.step_num = meta['training_info']['step']
if 'epoch' in meta['training_info']:
self.epoch_num = meta['training_info']['epoch']
self.start_step = self.step_num
print(f"Found step {self.step_num} in metadata, starting from there")
@@ -1035,7 +1047,16 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.sd.set_device_state(self.train_device_state_preset)
flush()
# self.step_num = 0
for step in range(self.step_num, self.train_config.steps):
###################################################################
# TRAIN LOOP
###################################################################
start_step_num = self.step_num
for step in range(start_step_num, self.train_config.steps):
self.step_num = step
# default to true so various things can turn it off
self.is_grad_accumulation_step = True
if self.train_config.free_u:
self.sd.pipeline.enable_freeu(s1=0.9, s2=0.2, b1=1.1, b2=1.2)
self.progress_bar.unpause()
@@ -1071,12 +1092,31 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.progress_bar.pause()
dataloader_iterator = iter(dataloader)
trigger_dataloader_setup_epoch(dataloader)
self.epoch_num += 1
if self.train_config.gradient_accumulation_steps == -1:
# if we are accumulating for an entire epoch, trigger a step
self.is_grad_accumulation_step = False
self.grad_accumulation_step = 0
with self.timer('get_batch'):
batch = next(dataloader_iterator)
self.progress_bar.unpause()
else:
batch = None
# setup accumulation
if self.train_config.gradient_accumulation_steps == -1:
# epoch is handling the accumulation, dont touch it
pass
else:
# determine if we are accumulating or not
# since optimizer step happens in the loop, we trigger it a step early
# since we cannot reprocess it before them
optimizer_step_at = self.train_config.gradient_accumulation_steps
is_optimizer_step = self.grad_accumulation_step >= optimizer_step_at
self.is_grad_accumulation_step = not is_optimizer_step
if is_optimizer_step:
self.grad_accumulation_step = 0
# flush()
### HOOK ###
self.timer.start('train_loop')
@@ -1144,17 +1184,24 @@ class BaseSDTrainProcess(BaseTrainProcess):
# sets progress bar to match out step
self.progress_bar.update(step - self.progress_bar.n)
# end of step
self.step_num = step
# flush every 10 steps
# if self.step_num % 10 == 0:
# flush()
#############################
# End of step
#############################
# update various steps
self.step_num = step + 1
self.grad_accumulation_step += 1
###################################################################
## END TRAIN LOOP
###################################################################
self.progress_bar.close()
if self.train_config.free_u:
self.sd.pipeline.disable_freeu()
self.sample(self.step_num + 1)
self.sample(self.step_num)
print("")
self.save()

View File

@@ -52,6 +52,7 @@ class LormModuleSettingsConfig:
class LoRMConfig:
def __init__(self, **kwargs):
self.extract_mode: str = kwargs.get('extract_mode', 'ratio')
self.do_conv: bool = kwargs.get('do_conv', False)
self.extract_mode_param: dict = kwargs.get('extract_mode_param', 0.25)
self.parameter_threshold: int = kwargs.get('parameter_threshold', 0)
module_settings = kwargs.get('module_settings', [])
@@ -110,6 +111,8 @@ class NetworkConfig:
# set linear to arbitrary values so it makes them
self.linear = 4
self.rank = 4
if self.lorm_config.do_conv:
self.conv = 4
AdapterTypes = Literal['t2i', 'ip', 'ip+']
@@ -177,6 +180,10 @@ class TrainConfig:
self.noise_multiplier = kwargs.get('noise_multiplier', 1.0)
self.img_multiplier = kwargs.get('img_multiplier', 1.0)
# set to -1 to accumulate gradients for entire epoch
# warning, only do this with a small dataset or you will run out of memory
self.gradient_accumulation_steps = kwargs.get('gradient_accumulation_steps', 1)
# short long captions will double your batch size. This only works when a dataset is
# prepared with a json caption file that has both short and long captions in it. It will
# Double up every image and run it through with both short and long captions. The idea

View File

@@ -191,6 +191,8 @@ def extract_conv(
if lora_rank >= out_ch / 2:
lora_rank = int(out_ch / 2)
print(f"rank is higher than it should be")
# print(f"Skipping layer as determined rank is too high")
# return None, None, None, None
# return weight, 'full'
U = U[:, :lora_rank]
@@ -243,6 +245,8 @@ def extract_linear(
# print(f"rank is higher than it should be")
lora_rank = int(out_ch / 2)
# return weight, 'full'
# print(f"Skipping layer as determined rank is too high")
# return None, None, None, None
U = U[:, :lora_rank]
S = S[:lora_rank]
@@ -358,6 +362,8 @@ def convert_diffusers_unet_to_lorm(
mode_param=extract_mode_param,
device=child_module.weight.device,
)
if down_weight is None:
continue
down_weight = down_weight.to(dtype=dtype)
up_weight = up_weight.to(dtype=dtype)
bias_weight = None
@@ -398,6 +404,8 @@ def convert_diffusers_unet_to_lorm(
mode_param=extract_mode_param,
device=child_module.weight.device,
)
if down_weight is None:
continue
down_weight = down_weight.to(dtype=dtype)
up_weight = up_weight.to(dtype=dtype)
bias_weight = None