Added proper grad accumulation

This commit is contained in:
Jaret Burkett
2024-09-03 07:24:18 -06:00
parent e5fadddd45
commit 121a760c19
3 changed files with 78 additions and 44 deletions


@@ -910,7 +910,7 @@ class SDTrainer(BaseSDTrainProcess):
**kwargs
)
-def hook_train_loop(self, batch: 'DataLoaderBatchDTO'):
+def train_single_accumulation(self, batch: DataLoaderBatchDTO):
self.timer.start('preprocess_batch')
batch = self.preprocess_batch(batch)
dtype = get_torch_dtype(self.train_config.dtype)
@@ -1243,7 +1243,8 @@ class SDTrainer(BaseSDTrainProcess):
pred_kwargs = {}
if has_adapter_img:
-if (self.adapter and isinstance(self.adapter, T2IAdapter)) or (self.assistant_adapter and isinstance(self.assistant_adapter, T2IAdapter)):
+if (self.adapter and isinstance(self.adapter, T2IAdapter)) or (
+self.assistant_adapter and isinstance(self.assistant_adapter, T2IAdapter)):
with torch.set_grad_enabled(self.adapter is not None):
adapter = self.assistant_adapter if self.assistant_adapter is not None else self.adapter
adapter_multiplier = get_adapter_multiplier()
@@ -1283,7 +1284,8 @@ class SDTrainer(BaseSDTrainProcess):
if self.train_config.do_cfg:
embeds = [
-load_file(random.choice(batch.clip_image_embeds_unconditional)) for i in range(noisy_latents.shape[0])
+load_file(random.choice(batch.clip_image_embeds_unconditional)) for i in
+range(noisy_latents.shape[0])
]
unconditional_clip_embeds = self.adapter.parse_clip_image_embeds_from_cache(
embeds,
@@ -1424,7 +1426,6 @@ class SDTrainer(BaseSDTrainProcess):
if prior_pred is not None:
prior_pred = prior_pred.detach()
# do the custom adapter after the prior prediction
if self.adapter and isinstance(self.adapter, CustomAdapter) and has_clip_image:
quad_count = random.randint(1, 4)
@@ -1450,10 +1451,12 @@ class SDTrainer(BaseSDTrainProcess):
self.adapter.add_extra_values(batch.extra_values.detach())
if self.train_config.do_cfg:
-self.adapter.add_extra_values(torch.zeros_like(batch.extra_values.detach()), is_unconditional=True)
+self.adapter.add_extra_values(torch.zeros_like(batch.extra_values.detach()),
+is_unconditional=True)
if has_adapter_img:
-if (self.adapter and isinstance(self.adapter, ControlNetModel)) or (self.assistant_adapter and isinstance(self.assistant_adapter, ControlNetModel)):
+if (self.adapter and isinstance(self.adapter, ControlNetModel)) or (
+self.assistant_adapter and isinstance(self.assistant_adapter, ControlNetModel)):
if self.train_config.do_cfg:
raise ValueError("ControlNetModel is not supported with CFG")
with torch.set_grad_enabled(self.adapter is not None):
@@ -1478,7 +1481,6 @@ class SDTrainer(BaseSDTrainProcess):
pred_kwargs['down_block_additional_residuals'] = down_block_res_samples
pred_kwargs['mid_block_additional_residual'] = mid_block_res_sample
self.before_unet_predict()
# do a prior pred if we have an unconditional image, we will swap out the guidance later
if batch.unconditional_latents is not None or self.do_guided_loss:
@@ -1526,7 +1528,6 @@ class SDTrainer(BaseSDTrainProcess):
print("loss is nan")
loss = torch.zeros_like(loss).requires_grad_(True)
with self.timer('backward'):
# todo we have multiplier separated. works for now as res are not in same batch, but need to change
loss = loss * loss_multiplier.mean()
@@ -1543,8 +1544,27 @@ class SDTrainer(BaseSDTrainProcess):
loss.backward()
else:
self.scaler.scale(loss).backward()
+return loss.detach()
# flush()
+def hook_train_loop(self, batch: Union[DataLoaderBatchDTO, List[DataLoaderBatchDTO]]):
+if isinstance(batch, list):
+batch_list = batch
+else:
+batch_list = [batch]
+total_loss = None
+for batch in batch_list:
+self.optimizer.zero_grad(set_to_none=True)
+loss = self.train_single_accumulation(batch)
+if total_loss is None:
+total_loss = loss
+else:
+total_loss += loss
+if len(batch_list) > 1 and self.model_config.low_vram:
+torch.cuda.empty_cache()
if not self.is_grad_accumulation_step:
# fix this for multi params
if self.train_config.optimizer != 'adafactor':

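For context, the train_single_accumulation / hook_train_loop split above implements standard gradient accumulation: run backward() on several micro-batches, then take a single optimizer step. Below is a minimal, self-contained sketch of that pattern with placeholder model, optimizer, and data (not ai-toolkit APIs):

```python
import torch
import torch.nn as nn

# Placeholder model and optimizer; the trainer above uses its own network and
# configured optimizer.
model = nn.Linear(16, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()

accumulation = 4  # plays the role of train_config.gradient_accumulation
micro_batches = [(torch.randn(2, 16), torch.randn(2, 1)) for _ in range(accumulation)]

optimizer.zero_grad(set_to_none=True)
total_loss = torch.tensor(0.0)
for inputs, targets in micro_batches:
    loss = loss_fn(model(inputs), targets)
    # Gradients simply sum across backward() calls. Scaling the loss by
    # 1/accumulation is a common convention so the summed gradients match one
    # large batch; whether to scale depends on how the loss is reduced.
    loss.backward()
    total_loss += loss.detach()

# One visible optimizer step per accumulation cycle, then clear gradients.
optimizer.step()
optimizer.zero_grad(set_to_none=True)
print(f"mean micro-batch loss: {(total_loss / accumulation).item():.4f}")
```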

@@ -1648,42 +1648,47 @@ class BaseSDTrainProcess(BaseTrainProcess):
is_sample_step = self.sample_config.sample_every and self.step_num % self.sample_config.sample_every == 0
if self.train_config.disable_sampling:
is_sample_step = False
-# don't do a reg step on sample or save steps as we don't want to normalize on those
-if step % 2 == 0 and dataloader_reg is not None and not is_save_step and not is_sample_step:
-try:
-with self.timer('get_batch:reg'):
-batch = next(dataloader_iterator_reg)
-except StopIteration:
-with self.timer('reset_batch:reg'):
-# hit the end of an epoch, reset
-self.progress_bar.pause()
-dataloader_iterator_reg = iter(dataloader_reg)
-trigger_dataloader_setup_epoch(dataloader_reg)
-with self.timer('get_batch:reg'):
-batch = next(dataloader_iterator_reg)
-self.progress_bar.unpause()
-is_reg_step = True
-elif dataloader is not None:
-try:
-with self.timer('get_batch'):
-batch = next(dataloader_iterator)
-except StopIteration:
-with self.timer('reset_batch'):
-# hit the end of an epoch, reset
-self.progress_bar.pause()
-dataloader_iterator = iter(dataloader)
-trigger_dataloader_setup_epoch(dataloader)
-self.epoch_num += 1
-if self.train_config.gradient_accumulation_steps == -1:
-# if we are accumulating for an entire epoch, trigger a step
-self.is_grad_accumulation_step = False
-self.grad_accumulation_step = 0
-with self.timer('get_batch'):
-batch = next(dataloader_iterator)
-self.progress_bar.unpause()
-else:
-batch = None
+batch_list = []
+for b in range(self.train_config.gradient_accumulation):
+# don't do a reg step on sample or save steps as we don't want to normalize on those
+if step % 2 == 0 and dataloader_reg is not None and not is_save_step and not is_sample_step:
+try:
+with self.timer('get_batch:reg'):
+batch = next(dataloader_iterator_reg)
+except StopIteration:
+with self.timer('reset_batch:reg'):
+# hit the end of an epoch, reset
+self.progress_bar.pause()
+dataloader_iterator_reg = iter(dataloader_reg)
+trigger_dataloader_setup_epoch(dataloader_reg)
+with self.timer('get_batch:reg'):
+batch = next(dataloader_iterator_reg)
+self.progress_bar.unpause()
+is_reg_step = True
+elif dataloader is not None:
+try:
+with self.timer('get_batch'):
+batch = next(dataloader_iterator)
+except StopIteration:
+with self.timer('reset_batch'):
+# hit the end of an epoch, reset
+self.progress_bar.pause()
+dataloader_iterator = iter(dataloader)
+trigger_dataloader_setup_epoch(dataloader)
+self.epoch_num += 1
+if self.train_config.gradient_accumulation_steps == -1:
+# if we are accumulating for an entire epoch, trigger a step
+self.is_grad_accumulation_step = False
+self.grad_accumulation_step = 0
+with self.timer('get_batch'):
+batch = next(dataloader_iterator)
+self.progress_bar.unpause()
+else:
+batch = None
+batch_list.append(batch)
# setup accumulation
if self.train_config.gradient_accumulation_steps == -1:
@@ -1701,7 +1706,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
# flush()
### HOOK ###
-loss_dict = self.hook_train_loop(batch)
+loss_dict = self.hook_train_loop(batch_list)
self.timer.stop('train_loop')
if not did_first_flush:
flush()

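The dataloader handling above keeps the same iterator-reset logic in both branches; the commit simply moves it inside a loop that gathers gradient_accumulation micro-batches into batch_list before calling the training hook once. A simplified sketch of that gather-then-train flow, with illustrative names rather than the trainer's actual helpers:

```python
from typing import Any, Iterator, List, Tuple

def next_batch(iterator: Iterator, dataloader) -> Tuple[Any, Iterator]:
    """Fetch the next batch, restarting the iterator when an epoch ends."""
    try:
        return next(iterator), iterator
    except StopIteration:
        # hit the end of an epoch, reset (the real trainer also re-triggers
        # per-epoch dataset setup and bumps epoch_num here)
        iterator = iter(dataloader)
        return next(iterator), iterator

def gather_batches(iterator: Iterator, dataloader,
                   gradient_accumulation: int) -> Tuple[List, Iterator]:
    """Collect one list of micro-batches for a single accumulated step."""
    batch_list = []
    for _ in range(gradient_accumulation):
        batch, iterator = next_batch(iterator, dataloader)
        batch_list.append(batch)
    return batch_list, iterator

# toy usage: a plain list stands in for a DataLoader
dataloader = ["b1", "b2", "b3"]
it = iter(dataloader)
batch_list, it = gather_batches(it, dataloader, gradient_accumulation=4)
print(batch_list)  # ['b1', 'b2', 'b3', 'b1'] -- wraps around the epoch boundary
```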

@@ -236,6 +236,7 @@ class TrainConfig:
self.min_denoising_steps: int = kwargs.get('min_denoising_steps', 0)
self.max_denoising_steps: int = kwargs.get('max_denoising_steps', 1000)
self.batch_size: int = kwargs.get('batch_size', 1)
+self.orig_batch_size: int = self.batch_size
self.dtype: str = kwargs.get('dtype', 'fp32')
self.xformers = kwargs.get('xformers', False)
self.sdp = kwargs.get('sdp', False)
@@ -284,8 +285,16 @@ class TrainConfig:
# set to -1 to accumulate gradients for entire epoch
# warning, only do this with a small dataset or you will run out of memory
+# This is legacy but left in for backwards compatibility
self.gradient_accumulation_steps = kwargs.get('gradient_accumulation_steps', 1)
+# this will do proper gradient accumulation where you will not see a step until the end of the accumulation
+# the legacy method above will show a step for every accumulation pass
+self.gradient_accumulation = kwargs.get('gradient_accumulation', 1)
+if self.gradient_accumulation > 1:
+if self.gradient_accumulation_steps != 1:
+raise ValueError("gradient_accumulation and gradient_accumulation_steps are mutually exclusive")
# short/long captions will double your batch size. This only works when a dataset is
# prepared with a json caption file that has both short and long captions in it. It will
# double up every image and run it through with both short and long captions. The idea
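Putting the two accumulation knobs side by side: gradient_accumulation_steps is the legacy option, where the progress counter ticks on every micro-batch, while the new gradient_accumulation collects micro-batches so only one step is visible per cycle, and the two cannot be combined. A standalone sketch of that relationship (an illustrative helper, not part of TrainConfig; the -1 whole-epoch mode is handled by the trainer itself and not modeled here):

```python
def effective_batch_size(batch_size: int,
                         gradient_accumulation: int = 1,
                         gradient_accumulation_steps: int = 1) -> int:
    """Samples contributing gradients to each optimizer step, under either knob."""
    if gradient_accumulation > 1 and gradient_accumulation_steps != 1:
        # mirrors the mutual-exclusivity check added to TrainConfig above
        raise ValueError(
            "gradient_accumulation and gradient_accumulation_steps are mutually exclusive")
    if gradient_accumulation_steps > 1:
        # legacy knob: the optimizer still steps every N micro-batches, but the
        # displayed step count advances once per micro-batch
        return batch_size * gradient_accumulation_steps
    # new knob: one visible step per fully accumulated cycle
    return batch_size * gradient_accumulation

print(effective_batch_size(batch_size=2, gradient_accumulation=4))        # 8
print(effective_batch_size(batch_size=2, gradient_accumulation_steps=4))  # 8
```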