mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 19:21:39 +00:00
Added gradient accumulation finally
This commit is contained in:
@@ -448,16 +448,17 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
# I spent weeks fighting this. DON'T DO IT
|
# I spent weeks fighting this. DON'T DO IT
|
||||||
# with fsdp_overlap_step_with_backward():
|
# with fsdp_overlap_step_with_backward():
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm)
|
|
||||||
# flush()
|
# flush()
|
||||||
|
|
||||||
with self.timer('optimizer_step'):
|
if not self.is_grad_accumulation_step:
|
||||||
# apply gradients
|
torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm)
|
||||||
self.optimizer.step()
|
# only step if we are not accumulating
|
||||||
self.optimizer.zero_grad(set_to_none=True)
|
with self.timer('optimizer_step'):
|
||||||
with self.timer('scheduler_step'):
|
# apply gradients
|
||||||
self.lr_scheduler.step()
|
self.optimizer.step()
|
||||||
|
self.optimizer.zero_grad(set_to_none=True)
|
||||||
|
with self.timer('scheduler_step'):
|
||||||
|
self.lr_scheduler.step()
|
||||||
|
|
||||||
if self.embedding is not None:
|
if self.embedding is not None:
|
||||||
with self.timer('restore_embeddings'):
|
with self.timer('restore_embeddings'):
|
||||||
|
|||||||
@@ -60,6 +60,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
self.custom_pipeline = custom_pipeline
|
self.custom_pipeline = custom_pipeline
|
||||||
self.step_num = 0
|
self.step_num = 0
|
||||||
self.start_step = 0
|
self.start_step = 0
|
||||||
|
self.epoch_num = 0
|
||||||
|
# start at 1 so we can do a sample at the start
|
||||||
|
self.grad_accumulation_step = 1
|
||||||
|
# if true, then we do not do an optimizer step. We are accumulating gradients
|
||||||
|
self.is_grad_accumulation_step = False
|
||||||
self.device = self.get_conf('device', self.job.device)
|
self.device = self.get_conf('device', self.job.device)
|
||||||
self.device_torch = torch.device(self.device)
|
self.device_torch = torch.device(self.device)
|
||||||
network_config = self.get_conf('network', None)
|
network_config = self.get_conf('network', None)
|
||||||
@@ -250,7 +255,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
|
|
||||||
def get_training_info(self):
|
def get_training_info(self):
|
||||||
info = OrderedDict({
|
info = OrderedDict({
|
||||||
'step': self.step_num + 1
|
'step': self.step_num,
|
||||||
|
'epoch': self.epoch_num,
|
||||||
})
|
})
|
||||||
return info
|
return info
|
||||||
|
|
||||||
@@ -417,6 +423,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
# if 'training_info' in OrderedDict keys
|
# if 'training_info' in OrderedDict keys
|
||||||
if 'training_info' in meta and 'step' in meta['training_info'] and self.train_config.start_step is None:
|
if 'training_info' in meta and 'step' in meta['training_info'] and self.train_config.start_step is None:
|
||||||
self.step_num = meta['training_info']['step']
|
self.step_num = meta['training_info']['step']
|
||||||
|
if 'epoch' in meta['training_info']:
|
||||||
|
self.epoch_num = meta['training_info']['epoch']
|
||||||
self.start_step = self.step_num
|
self.start_step = self.step_num
|
||||||
print(f"Found step {self.step_num} in metadata, starting from there")
|
print(f"Found step {self.step_num} in metadata, starting from there")
|
||||||
|
|
||||||
@@ -441,6 +449,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
# if 'training_info' in OrderedDict keys
|
# if 'training_info' in OrderedDict keys
|
||||||
if 'training_info' in meta and 'step' in meta['training_info']:
|
if 'training_info' in meta and 'step' in meta['training_info']:
|
||||||
self.step_num = meta['training_info']['step']
|
self.step_num = meta['training_info']['step']
|
||||||
|
if 'epoch' in meta['training_info']:
|
||||||
|
self.epoch_num = meta['training_info']['epoch']
|
||||||
self.start_step = self.step_num
|
self.start_step = self.step_num
|
||||||
print(f"Found step {self.step_num} in metadata, starting from there")
|
print(f"Found step {self.step_num} in metadata, starting from there")
|
||||||
|
|
||||||
@@ -712,6 +722,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
# if 'training_info' in OrderedDict keys
|
# if 'training_info' in OrderedDict keys
|
||||||
if 'training_info' in meta and 'step' in meta['training_info']:
|
if 'training_info' in meta and 'step' in meta['training_info']:
|
||||||
self.step_num = meta['training_info']['step']
|
self.step_num = meta['training_info']['step']
|
||||||
|
if 'epoch' in meta['training_info']:
|
||||||
|
self.epoch_num = meta['training_info']['epoch']
|
||||||
self.start_step = self.step_num
|
self.start_step = self.step_num
|
||||||
print(f"Found step {self.step_num} in metadata, starting from there")
|
print(f"Found step {self.step_num} in metadata, starting from there")
|
||||||
|
|
||||||
@@ -1035,7 +1047,16 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
self.sd.set_device_state(self.train_device_state_preset)
|
self.sd.set_device_state(self.train_device_state_preset)
|
||||||
flush()
|
flush()
|
||||||
# self.step_num = 0
|
# self.step_num = 0
|
||||||
for step in range(self.step_num, self.train_config.steps):
|
|
||||||
|
###################################################################
|
||||||
|
# TRAIN LOOP
|
||||||
|
###################################################################
|
||||||
|
|
||||||
|
start_step_num = self.step_num
|
||||||
|
for step in range(start_step_num, self.train_config.steps):
|
||||||
|
self.step_num = step
|
||||||
|
# default to true so various things can turn it off
|
||||||
|
self.is_grad_accumulation_step = True
|
||||||
if self.train_config.free_u:
|
if self.train_config.free_u:
|
||||||
self.sd.pipeline.enable_freeu(s1=0.9, s2=0.2, b1=1.1, b2=1.2)
|
self.sd.pipeline.enable_freeu(s1=0.9, s2=0.2, b1=1.1, b2=1.2)
|
||||||
self.progress_bar.unpause()
|
self.progress_bar.unpause()
|
||||||
@@ -1071,12 +1092,31 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
self.progress_bar.pause()
|
self.progress_bar.pause()
|
||||||
dataloader_iterator = iter(dataloader)
|
dataloader_iterator = iter(dataloader)
|
||||||
trigger_dataloader_setup_epoch(dataloader)
|
trigger_dataloader_setup_epoch(dataloader)
|
||||||
|
self.epoch_num += 1
|
||||||
|
if self.train_config.gradient_accumulation_steps == -1:
|
||||||
|
# if we are accumulating for an entire epoch, trigger a step
|
||||||
|
self.is_grad_accumulation_step = False
|
||||||
|
self.grad_accumulation_step = 0
|
||||||
with self.timer('get_batch'):
|
with self.timer('get_batch'):
|
||||||
batch = next(dataloader_iterator)
|
batch = next(dataloader_iterator)
|
||||||
self.progress_bar.unpause()
|
self.progress_bar.unpause()
|
||||||
else:
|
else:
|
||||||
batch = None
|
batch = None
|
||||||
|
|
||||||
|
# setup accumulation
|
||||||
|
if self.train_config.gradient_accumulation_steps == -1:
|
||||||
|
# epoch is handling the accumulation, don't touch it
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
# determine if we are accumulating or not
|
||||||
|
# since optimizer step happens in the loop, we trigger it a step early
|
||||||
|
# since we cannot reprocess it before then
|
||||||
|
optimizer_step_at = self.train_config.gradient_accumulation_steps
|
||||||
|
is_optimizer_step = self.grad_accumulation_step >= optimizer_step_at
|
||||||
|
self.is_grad_accumulation_step = not is_optimizer_step
|
||||||
|
if is_optimizer_step:
|
||||||
|
self.grad_accumulation_step = 0
|
||||||
|
|
||||||
# flush()
|
# flush()
|
||||||
### HOOK ###
|
### HOOK ###
|
||||||
self.timer.start('train_loop')
|
self.timer.start('train_loop')
|
||||||
@@ -1144,17 +1184,24 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
|
|
||||||
# sets progress bar to match our step
|
# sets progress bar to match our step
|
||||||
self.progress_bar.update(step - self.progress_bar.n)
|
self.progress_bar.update(step - self.progress_bar.n)
|
||||||
# end of step
|
|
||||||
self.step_num = step
|
|
||||||
|
|
||||||
# flush every 10 steps
|
#############################
|
||||||
# if self.step_num % 10 == 0:
|
# End of step
|
||||||
# flush()
|
#############################
|
||||||
|
|
||||||
|
# update various steps
|
||||||
|
self.step_num = step + 1
|
||||||
|
self.grad_accumulation_step += 1
|
||||||
|
|
||||||
|
|
||||||
|
###################################################################
|
||||||
|
## END TRAIN LOOP
|
||||||
|
###################################################################
|
||||||
|
|
||||||
self.progress_bar.close()
|
self.progress_bar.close()
|
||||||
if self.train_config.free_u:
|
if self.train_config.free_u:
|
||||||
self.sd.pipeline.disable_freeu()
|
self.sd.pipeline.disable_freeu()
|
||||||
self.sample(self.step_num + 1)
|
self.sample(self.step_num)
|
||||||
print("")
|
print("")
|
||||||
self.save()
|
self.save()
|
||||||
|
|
||||||
|
|||||||
@@ -52,6 +52,7 @@ class LormModuleSettingsConfig:
|
|||||||
class LoRMConfig:
|
class LoRMConfig:
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
self.extract_mode: str = kwargs.get('extract_mode', 'ratio')
|
self.extract_mode: str = kwargs.get('extract_mode', 'ratio')
|
||||||
|
self.do_conv: bool = kwargs.get('do_conv', False)
|
||||||
self.extract_mode_param: dict = kwargs.get('extract_mode_param', 0.25)
|
self.extract_mode_param: dict = kwargs.get('extract_mode_param', 0.25)
|
||||||
self.parameter_threshold: int = kwargs.get('parameter_threshold', 0)
|
self.parameter_threshold: int = kwargs.get('parameter_threshold', 0)
|
||||||
module_settings = kwargs.get('module_settings', [])
|
module_settings = kwargs.get('module_settings', [])
|
||||||
@@ -110,6 +111,8 @@ class NetworkConfig:
|
|||||||
# set linear to arbitrary values so it makes them
|
# set linear to arbitrary values so it makes them
|
||||||
self.linear = 4
|
self.linear = 4
|
||||||
self.rank = 4
|
self.rank = 4
|
||||||
|
if self.lorm_config.do_conv:
|
||||||
|
self.conv = 4
|
||||||
|
|
||||||
|
|
||||||
AdapterTypes = Literal['t2i', 'ip', 'ip+']
|
AdapterTypes = Literal['t2i', 'ip', 'ip+']
|
||||||
@@ -177,6 +180,10 @@ class TrainConfig:
|
|||||||
self.noise_multiplier = kwargs.get('noise_multiplier', 1.0)
|
self.noise_multiplier = kwargs.get('noise_multiplier', 1.0)
|
||||||
self.img_multiplier = kwargs.get('img_multiplier', 1.0)
|
self.img_multiplier = kwargs.get('img_multiplier', 1.0)
|
||||||
|
|
||||||
|
# set to -1 to accumulate gradients for entire epoch
|
||||||
|
# warning, only do this with a small dataset or you will run out of memory
|
||||||
|
self.gradient_accumulation_steps = kwargs.get('gradient_accumulation_steps', 1)
|
||||||
|
|
||||||
# short long captions will double your batch size. This only works when a dataset is
|
# short long captions will double your batch size. This only works when a dataset is
|
||||||
# prepared with a json caption file that has both short and long captions in it. It will
|
# prepared with a json caption file that has both short and long captions in it. It will
|
||||||
# Double up every image and run it through with both short and long captions. The idea
|
# Double up every image and run it through with both short and long captions. The idea
|
||||||
|
|||||||
@@ -191,6 +191,8 @@ def extract_conv(
|
|||||||
if lora_rank >= out_ch / 2:
|
if lora_rank >= out_ch / 2:
|
||||||
lora_rank = int(out_ch / 2)
|
lora_rank = int(out_ch / 2)
|
||||||
print(f"rank is higher than it should be")
|
print(f"rank is higher than it should be")
|
||||||
|
# print(f"Skipping layer as determined rank is too high")
|
||||||
|
# return None, None, None, None
|
||||||
# return weight, 'full'
|
# return weight, 'full'
|
||||||
|
|
||||||
U = U[:, :lora_rank]
|
U = U[:, :lora_rank]
|
||||||
@@ -243,6 +245,8 @@ def extract_linear(
|
|||||||
# print(f"rank is higher than it should be")
|
# print(f"rank is higher than it should be")
|
||||||
lora_rank = int(out_ch / 2)
|
lora_rank = int(out_ch / 2)
|
||||||
# return weight, 'full'
|
# return weight, 'full'
|
||||||
|
# print(f"Skipping layer as determined rank is too high")
|
||||||
|
# return None, None, None, None
|
||||||
|
|
||||||
U = U[:, :lora_rank]
|
U = U[:, :lora_rank]
|
||||||
S = S[:lora_rank]
|
S = S[:lora_rank]
|
||||||
@@ -358,6 +362,8 @@ def convert_diffusers_unet_to_lorm(
|
|||||||
mode_param=extract_mode_param,
|
mode_param=extract_mode_param,
|
||||||
device=child_module.weight.device,
|
device=child_module.weight.device,
|
||||||
)
|
)
|
||||||
|
if down_weight is None:
|
||||||
|
continue
|
||||||
down_weight = down_weight.to(dtype=dtype)
|
down_weight = down_weight.to(dtype=dtype)
|
||||||
up_weight = up_weight.to(dtype=dtype)
|
up_weight = up_weight.to(dtype=dtype)
|
||||||
bias_weight = None
|
bias_weight = None
|
||||||
@@ -398,6 +404,8 @@ def convert_diffusers_unet_to_lorm(
|
|||||||
mode_param=extract_mode_param,
|
mode_param=extract_mode_param,
|
||||||
device=child_module.weight.device,
|
device=child_module.weight.device,
|
||||||
)
|
)
|
||||||
|
if down_weight is None:
|
||||||
|
continue
|
||||||
down_weight = down_weight.to(dtype=dtype)
|
down_weight = down_weight.to(dtype=dtype)
|
||||||
up_weight = up_weight.to(dtype=dtype)
|
up_weight = up_weight.to(dtype=dtype)
|
||||||
bias_weight = None
|
bias_weight = None
|
||||||
|
|||||||
Reference in New Issue
Block a user