Mirror of https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Added proper grad accumulation
@@ -910,7 +910,7 @@ class SDTrainer(BaseSDTrainProcess):
             **kwargs
         )

-    def hook_train_loop(self, batch: 'DataLoaderBatchDTO'):
+    def train_single_accumulation(self, batch: DataLoaderBatchDTO):
         self.timer.start('preprocess_batch')
         batch = self.preprocess_batch(batch)
         dtype = get_torch_dtype(self.train_config.dtype)
@@ -1243,7 +1243,8 @@ class SDTrainer(BaseSDTrainProcess):
         pred_kwargs = {}

         if has_adapter_img:
-            if (self.adapter and isinstance(self.adapter, T2IAdapter)) or (self.assistant_adapter and isinstance(self.assistant_adapter, T2IAdapter)):
+            if (self.adapter and isinstance(self.adapter, T2IAdapter)) or (
+                    self.assistant_adapter and isinstance(self.assistant_adapter, T2IAdapter)):
                 with torch.set_grad_enabled(self.adapter is not None):
                     adapter = self.assistant_adapter if self.assistant_adapter is not None else self.adapter
                     adapter_multiplier = get_adapter_multiplier()
@@ -1283,7 +1284,8 @@ class SDTrainer(BaseSDTrainProcess):

                 if self.train_config.do_cfg:
                     embeds = [
-                        load_file(random.choice(batch.clip_image_embeds_unconditional)) for i in range(noisy_latents.shape[0])
+                        load_file(random.choice(batch.clip_image_embeds_unconditional)) for i in
+                        range(noisy_latents.shape[0])
                     ]
                     unconditional_clip_embeds = self.adapter.parse_clip_image_embeds_from_cache(
                         embeds,
@@ -1424,7 +1426,6 @@ class SDTrainer(BaseSDTrainProcess):
         if prior_pred is not None:
             prior_pred = prior_pred.detach()

-
         # do the custom adapter after the prior prediction
         if self.adapter and isinstance(self.adapter, CustomAdapter) and has_clip_image:
             quad_count = random.randint(1, 4)
@@ -1450,10 +1451,12 @@ class SDTrainer(BaseSDTrainProcess):
             self.adapter.add_extra_values(batch.extra_values.detach())

             if self.train_config.do_cfg:
-                self.adapter.add_extra_values(torch.zeros_like(batch.extra_values.detach()), is_unconditional=True)
+                self.adapter.add_extra_values(torch.zeros_like(batch.extra_values.detach()),
+                                              is_unconditional=True)

         if has_adapter_img:
-            if (self.adapter and isinstance(self.adapter, ControlNetModel)) or (self.assistant_adapter and isinstance(self.assistant_adapter, ControlNetModel)):
+            if (self.adapter and isinstance(self.adapter, ControlNetModel)) or (
+                    self.assistant_adapter and isinstance(self.assistant_adapter, ControlNetModel)):
                 if self.train_config.do_cfg:
                     raise ValueError("ControlNetModel is not supported with CFG")
                 with torch.set_grad_enabled(self.adapter is not None):
@@ -1478,7 +1481,6 @@ class SDTrainer(BaseSDTrainProcess):
                 pred_kwargs['down_block_additional_residuals'] = down_block_res_samples
                 pred_kwargs['mid_block_additional_residual'] = mid_block_res_sample

-
         self.before_unet_predict()
         # do a prior pred if we have an unconditional image, we will swap out the guidance later
         if batch.unconditional_latents is not None or self.do_guided_loss:
@@ -1526,7 +1528,6 @@ class SDTrainer(BaseSDTrainProcess):
             print("loss is nan")
             loss = torch.zeros_like(loss).requires_grad_(True)

-
         with self.timer('backward'):
             # todo we have multiplier separated. works for now as res are not in same batch, but need to change
             loss = loss * loss_multiplier.mean()
@@ -1543,8 +1544,27 @@ class SDTrainer(BaseSDTrainProcess):
                 loss.backward()
             else:
                 self.scaler.scale(loss).backward()

+        return loss.detach()
         # flush()
+
+    def hook_train_loop(self, batch: Union[DataLoaderBatchDTO, List[DataLoaderBatchDTO]]):
+        if isinstance(batch, list):
+            batch_list = batch
+        else:
+            batch_list = [batch]
+        total_loss = None
+        for batch in batch_list:
+            self.optimizer.zero_grad(set_to_none=True)
+            loss = self.train_single_accumulation(batch)
+            if total_loss is None:
+                total_loss = loss
+            else:
+                total_loss += loss
+            if len(batch_list) > 1 and self.model_config.low_vram:
+                torch.cuda.empty_cache()
+
+
         if not self.is_grad_accumulation_step:
             # fix this for multi params
             if self.train_config.optimizer != 'adafactor':
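For orientation, the textbook shape of the pattern this commit moves toward is sketched below. This is my own generic illustration, not code from the repository: model, optimizer, loss_fn, and micro_batches are placeholders, and the real trainer routes the backward pass through train_single_accumulation and its GradScaler instead.

def accumulation_step(model, optimizer, loss_fn, micro_batches):
    # Generic gradient-accumulation sketch (illustration only, not this repo's code).
    optimizer.zero_grad(set_to_none=True)                   # clear grads once per optimizer step
    num_micro = len(micro_batches)
    for inputs, targets in micro_batches:
        loss = loss_fn(model(inputs), targets) / num_micro  # scale so accumulated grads average
        loss.backward()                                      # gradients add up across micro-batches
    optimizer.step()                                         # one visible step for the whole group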
@@ -1648,42 +1648,47 @@ class BaseSDTrainProcess(BaseTrainProcess):
             is_sample_step = self.sample_config.sample_every and self.step_num % self.sample_config.sample_every == 0
             if self.train_config.disable_sampling:
                 is_sample_step = False
-            # don't do a reg step on sample or save steps as we don't want to normalize on those
-            if step % 2 == 0 and dataloader_reg is not None and not is_save_step and not is_sample_step:
-                try:
-                    with self.timer('get_batch:reg'):
-                        batch = next(dataloader_iterator_reg)
-                except StopIteration:
-                    with self.timer('reset_batch:reg'):
-                        # hit the end of an epoch, reset
-                        self.progress_bar.pause()
-                        dataloader_iterator_reg = iter(dataloader_reg)
-                        trigger_dataloader_setup_epoch(dataloader_reg)
-
-                    with self.timer('get_batch:reg'):
-                        batch = next(dataloader_iterator_reg)
-                    self.progress_bar.unpause()
-                is_reg_step = True
-            elif dataloader is not None:
-                try:
-                    with self.timer('get_batch'):
-                        batch = next(dataloader_iterator)
-                except StopIteration:
-                    with self.timer('reset_batch'):
-                        # hit the end of an epoch, reset
-                        self.progress_bar.pause()
-                        dataloader_iterator = iter(dataloader)
-                        trigger_dataloader_setup_epoch(dataloader)
-                        self.epoch_num += 1
-                        if self.train_config.gradient_accumulation_steps == -1:
-                            # if we are accumulating for an entire epoch, trigger a step
-                            self.is_grad_accumulation_step = False
-                            self.grad_accumulation_step = 0
-                    with self.timer('get_batch'):
-                        batch = next(dataloader_iterator)
-                    self.progress_bar.unpause()
-            else:
-                batch = None
+            batch_list = []
+
+            for b in range(self.train_config.gradient_accumulation):
+                # don't do a reg step on sample or save steps as we don't want to normalize on those
+                if step % 2 == 0 and dataloader_reg is not None and not is_save_step and not is_sample_step:
+                    try:
+                        with self.timer('get_batch:reg'):
+                            batch = next(dataloader_iterator_reg)
+                    except StopIteration:
+                        with self.timer('reset_batch:reg'):
+                            # hit the end of an epoch, reset
+                            self.progress_bar.pause()
+                            dataloader_iterator_reg = iter(dataloader_reg)
+                            trigger_dataloader_setup_epoch(dataloader_reg)
+
+                        with self.timer('get_batch:reg'):
+                            batch = next(dataloader_iterator_reg)
+                        self.progress_bar.unpause()
+                    is_reg_step = True
+                elif dataloader is not None:
+                    try:
+                        with self.timer('get_batch'):
+                            batch = next(dataloader_iterator)
+                    except StopIteration:
+                        with self.timer('reset_batch'):
+                            # hit the end of an epoch, reset
+                            self.progress_bar.pause()
+                            dataloader_iterator = iter(dataloader)
+                            trigger_dataloader_setup_epoch(dataloader)
+                            self.epoch_num += 1
+                            if self.train_config.gradient_accumulation_steps == -1:
+                                # if we are accumulating for an entire epoch, trigger a step
+                                self.is_grad_accumulation_step = False
+                                self.grad_accumulation_step = 0
+                        with self.timer('get_batch'):
+                            batch = next(dataloader_iterator)
+                        self.progress_bar.unpause()
+                else:
+                    batch = None
+                batch_list.append(batch)

             # setup accumulation
             if self.train_config.gradient_accumulation_steps == -1:
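Rough arithmetic for what the loop above changes (my own illustration; the numbers are hypothetical, not from this commit): the outer training step now drains gradient_accumulation batches from the dataloader before calling the hook, so each optimizer step sees batch_size * gradient_accumulation samples.

# Hypothetical values for illustration only; neither comes from this commit.
batch_size = 2
gradient_accumulation = 4

samples_per_optimizer_step = batch_size * gradient_accumulation
print(samples_per_optimizer_step)  # 8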
@@ -1701,7 +1706,7 @@ class BaseSDTrainProcess(BaseTrainProcess):

             # flush()
             ### HOOK ###
-            loss_dict = self.hook_train_loop(batch)
+            loss_dict = self.hook_train_loop(batch_list)
             self.timer.stop('train_loop')
             if not did_first_flush:
                 flush()
@@ -236,6 +236,7 @@ class TrainConfig:
         self.min_denoising_steps: int = kwargs.get('min_denoising_steps', 0)
         self.max_denoising_steps: int = kwargs.get('max_denoising_steps', 1000)
         self.batch_size: int = kwargs.get('batch_size', 1)
+        self.orig_batch_size: int = self.batch_size
         self.dtype: str = kwargs.get('dtype', 'fp32')
         self.xformers = kwargs.get('xformers', False)
         self.sdp = kwargs.get('sdp', False)
@@ -284,8 +285,16 @@ class TrainConfig:

         # set to -1 to accumulate gradients for entire epoch
         # warning, only do this with a small dataset or you will run out of memory
+        # This is legacy but left in for backwards compatibility
         self.gradient_accumulation_steps = kwargs.get('gradient_accumulation_steps', 1)
+
+        # this will do proper gradient accumulation where you will not see a step until the end of the accumulation
+        # the method above will show a step every accumulation
+        self.gradient_accumulation = kwargs.get('gradient_accumulation', 1)
+        if self.gradient_accumulation > 1:
+            if self.gradient_accumulation_steps != 1:
+                raise ValueError("gradient_accumulation and gradient_accumulation_steps are mutually exclusive")

         # short long captions will double your batch size. This only works when a dataset is
         # prepared with a json caption file that has both short and long captions in it. It will
         # Double up every image and run it through with both short and long captions. The idea
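A minimal sketch of selecting the new behaviour, assuming TrainConfig can be constructed straight from keyword arguments as the __init__ above suggests (the import path is a guess on my part):

from toolkit.config_modules import TrainConfig  # assumed import path

# New-style accumulation: four batches are accumulated per visible optimizer step.
cfg = TrainConfig(batch_size=2, gradient_accumulation=4)

# Mixing the two mechanisms should fail, per the check added above:
# TrainConfig(gradient_accumulation=4, gradient_accumulation_steps=8)  # raises ValueError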