mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 02:31:17 +00:00
Added gradient accumulation finally
This commit is contained in:
@@ -60,6 +60,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
self.custom_pipeline = custom_pipeline
|
||||
self.step_num = 0
|
||||
self.start_step = 0
|
||||
self.epoch_num = 0
|
||||
# start at 1 so we can do a sample at the start
|
||||
self.grad_accumulation_step = 1
|
||||
# if true, then we do not do an optimizer step. We are accumulating gradients
|
||||
self.is_grad_accumulation_step = False
|
||||
self.device = self.get_conf('device', self.job.device)
|
||||
self.device_torch = torch.device(self.device)
|
||||
network_config = self.get_conf('network', None)
|
||||
@@ -250,7 +255,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
|
||||
def get_training_info(self):
|
||||
info = OrderedDict({
|
||||
'step': self.step_num + 1
|
||||
'step': self.step_num,
|
||||
'epoch': self.epoch_num,
|
||||
})
|
||||
return info
|
||||
|
||||
@@ -417,6 +423,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
# if 'training_info' in Orderdict keys
|
||||
if 'training_info' in meta and 'step' in meta['training_info'] and self.train_config.start_step is None:
|
||||
self.step_num = meta['training_info']['step']
|
||||
if 'epoch' in meta['training_info']:
|
||||
self.epoch_num = meta['training_info']['epoch']
|
||||
self.start_step = self.step_num
|
||||
print(f"Found step {self.step_num} in metadata, starting from there")
|
||||
|
||||
@@ -441,6 +449,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
# if 'training_info' in Orderdict keys
|
||||
if 'training_info' in meta and 'step' in meta['training_info']:
|
||||
self.step_num = meta['training_info']['step']
|
||||
if 'epoch' in meta['training_info']:
|
||||
self.epoch_num = meta['training_info']['epoch']
|
||||
self.start_step = self.step_num
|
||||
print(f"Found step {self.step_num} in metadata, starting from there")
|
||||
|
||||
@@ -712,6 +722,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
# if 'training_info' in Orderdict keys
|
||||
if 'training_info' in meta and 'step' in meta['training_info']:
|
||||
self.step_num = meta['training_info']['step']
|
||||
if 'epoch' in meta['training_info']:
|
||||
self.epoch_num = meta['training_info']['epoch']
|
||||
self.start_step = self.step_num
|
||||
print(f"Found step {self.step_num} in metadata, starting from there")
|
||||
|
||||
@@ -1035,7 +1047,16 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
self.sd.set_device_state(self.train_device_state_preset)
|
||||
flush()
|
||||
# self.step_num = 0
|
||||
for step in range(self.step_num, self.train_config.steps):
|
||||
|
||||
###################################################################
|
||||
# TRAIN LOOP
|
||||
###################################################################
|
||||
|
||||
start_step_num = self.step_num
|
||||
for step in range(start_step_num, self.train_config.steps):
|
||||
self.step_num = step
|
||||
# default to true so various things can turn it off
|
||||
self.is_grad_accumulation_step = True
|
||||
if self.train_config.free_u:
|
||||
self.sd.pipeline.enable_freeu(s1=0.9, s2=0.2, b1=1.1, b2=1.2)
|
||||
self.progress_bar.unpause()
|
||||
@@ -1071,12 +1092,31 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
self.progress_bar.pause()
|
||||
dataloader_iterator = iter(dataloader)
|
||||
trigger_dataloader_setup_epoch(dataloader)
|
||||
self.epoch_num += 1
|
||||
if self.train_config.gradient_accumulation_steps == -1:
|
||||
# if we are accumulating for an entire epoch, trigger a step
|
||||
self.is_grad_accumulation_step = False
|
||||
self.grad_accumulation_step = 0
|
||||
with self.timer('get_batch'):
|
||||
batch = next(dataloader_iterator)
|
||||
self.progress_bar.unpause()
|
||||
else:
|
||||
batch = None
|
||||
|
||||
# setup accumulation
|
||||
if self.train_config.gradient_accumulation_steps == -1:
|
||||
# epoch is handling the accumulation, dont touch it
|
||||
pass
|
||||
else:
|
||||
# determine if we are accumulating or not
|
||||
# since optimizer step happens in the loop, we trigger it a step early
|
||||
# since we cannot reprocess it before them
|
||||
optimizer_step_at = self.train_config.gradient_accumulation_steps
|
||||
is_optimizer_step = self.grad_accumulation_step >= optimizer_step_at
|
||||
self.is_grad_accumulation_step = not is_optimizer_step
|
||||
if is_optimizer_step:
|
||||
self.grad_accumulation_step = 0
|
||||
|
||||
# flush()
|
||||
### HOOK ###
|
||||
self.timer.start('train_loop')
|
||||
@@ -1144,17 +1184,24 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
|
||||
# sets progress bar to match out step
|
||||
self.progress_bar.update(step - self.progress_bar.n)
|
||||
# end of step
|
||||
self.step_num = step
|
||||
|
||||
# flush every 10 steps
|
||||
# if self.step_num % 10 == 0:
|
||||
# flush()
|
||||
#############################
|
||||
# End of step
|
||||
#############################
|
||||
|
||||
# update various steps
|
||||
self.step_num = step + 1
|
||||
self.grad_accumulation_step += 1
|
||||
|
||||
|
||||
###################################################################
|
||||
## END TRAIN LOOP
|
||||
###################################################################
|
||||
|
||||
self.progress_bar.close()
|
||||
if self.train_config.free_u:
|
||||
self.sd.pipeline.disable_freeu()
|
||||
self.sample(self.step_num + 1)
|
||||
self.sample(self.step_num)
|
||||
print("")
|
||||
self.save()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user