Added gradient accumulation finally

This commit is contained in:
Jaret Burkett
2023-10-28 13:14:29 -06:00
parent 6f3e0d5af2
commit 298001439a
4 changed files with 79 additions and 16 deletions

View File

@@ -60,6 +60,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.custom_pipeline = custom_pipeline
self.step_num = 0
self.start_step = 0
self.epoch_num = 0
# start at 1 so we can do a sample at the start
self.grad_accumulation_step = 1
# if true, then we do not do an optimizer step. We are accumulating gradients
self.is_grad_accumulation_step = False
self.device = self.get_conf('device', self.job.device)
self.device_torch = torch.device(self.device)
network_config = self.get_conf('network', None)
@@ -250,7 +255,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
def get_training_info(self):
info = OrderedDict({
'step': self.step_num + 1
'step': self.step_num,
'epoch': self.epoch_num,
})
return info
@@ -417,6 +423,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
# if 'training_info' in Orderdict keys
if 'training_info' in meta and 'step' in meta['training_info'] and self.train_config.start_step is None:
self.step_num = meta['training_info']['step']
if 'epoch' in meta['training_info']:
self.epoch_num = meta['training_info']['epoch']
self.start_step = self.step_num
print(f"Found step {self.step_num} in metadata, starting from there")
@@ -441,6 +449,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
# if 'training_info' in Orderdict keys
if 'training_info' in meta and 'step' in meta['training_info']:
self.step_num = meta['training_info']['step']
if 'epoch' in meta['training_info']:
self.epoch_num = meta['training_info']['epoch']
self.start_step = self.step_num
print(f"Found step {self.step_num} in metadata, starting from there")
@@ -712,6 +722,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
# if 'training_info' in Orderdict keys
if 'training_info' in meta and 'step' in meta['training_info']:
self.step_num = meta['training_info']['step']
if 'epoch' in meta['training_info']:
self.epoch_num = meta['training_info']['epoch']
self.start_step = self.step_num
print(f"Found step {self.step_num} in metadata, starting from there")
@@ -1035,7 +1047,16 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.sd.set_device_state(self.train_device_state_preset)
flush()
# self.step_num = 0
for step in range(self.step_num, self.train_config.steps):
###################################################################
# TRAIN LOOP
###################################################################
start_step_num = self.step_num
for step in range(start_step_num, self.train_config.steps):
self.step_num = step
# default to true so various things can turn it off
self.is_grad_accumulation_step = True
if self.train_config.free_u:
self.sd.pipeline.enable_freeu(s1=0.9, s2=0.2, b1=1.1, b2=1.2)
self.progress_bar.unpause()
@@ -1071,12 +1092,31 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.progress_bar.pause()
dataloader_iterator = iter(dataloader)
trigger_dataloader_setup_epoch(dataloader)
self.epoch_num += 1
if self.train_config.gradient_accumulation_steps == -1:
# if we are accumulating for an entire epoch, trigger a step
self.is_grad_accumulation_step = False
self.grad_accumulation_step = 0
with self.timer('get_batch'):
batch = next(dataloader_iterator)
self.progress_bar.unpause()
else:
batch = None
# setup accumulation
if self.train_config.gradient_accumulation_steps == -1:
# epoch is handling the accumulation, dont touch it
pass
else:
# determine if we are accumulating or not
# since optimizer step happens in the loop, we trigger it a step early
# since we cannot reprocess it before them
optimizer_step_at = self.train_config.gradient_accumulation_steps
is_optimizer_step = self.grad_accumulation_step >= optimizer_step_at
self.is_grad_accumulation_step = not is_optimizer_step
if is_optimizer_step:
self.grad_accumulation_step = 0
# flush()
### HOOK ###
self.timer.start('train_loop')
@@ -1144,17 +1184,24 @@ class BaseSDTrainProcess(BaseTrainProcess):
# sets progress bar to match out step
self.progress_bar.update(step - self.progress_bar.n)
# end of step
self.step_num = step
# flush every 10 steps
# if self.step_num % 10 == 0:
# flush()
#############################
# End of step
#############################
# update various steps
self.step_num = step + 1
self.grad_accumulation_step += 1
###################################################################
## END TRAIN LOOP
###################################################################
self.progress_bar.close()
if self.train_config.free_u:
self.sd.pipeline.disable_freeu()
self.sample(self.step_num + 1)
self.sample(self.step_num)
print("")
self.save()