Added proper grad accumulation

This commit is contained in:
Jaret Burkett
2024-09-03 07:24:18 -06:00
parent e5fadddd45
commit 121a760c19
3 changed files with 78 additions and 44 deletions

View File

@@ -910,7 +910,7 @@ class SDTrainer(BaseSDTrainProcess):
**kwargs
)
def hook_train_loop(self, batch: 'DataLoaderBatchDTO'):
def train_single_accumulation(self, batch: DataLoaderBatchDTO):
self.timer.start('preprocess_batch')
batch = self.preprocess_batch(batch)
dtype = get_torch_dtype(self.train_config.dtype)
@@ -1243,7 +1243,8 @@ class SDTrainer(BaseSDTrainProcess):
pred_kwargs = {}
if has_adapter_img:
if (self.adapter and isinstance(self.adapter, T2IAdapter)) or (self.assistant_adapter and isinstance(self.assistant_adapter, T2IAdapter)):
if (self.adapter and isinstance(self.adapter, T2IAdapter)) or (
self.assistant_adapter and isinstance(self.assistant_adapter, T2IAdapter)):
with torch.set_grad_enabled(self.adapter is not None):
adapter = self.assistant_adapter if self.assistant_adapter is not None else self.adapter
adapter_multiplier = get_adapter_multiplier()
@@ -1283,7 +1284,8 @@ class SDTrainer(BaseSDTrainProcess):
if self.train_config.do_cfg:
embeds = [
load_file(random.choice(batch.clip_image_embeds_unconditional)) for i in range(noisy_latents.shape[0])
load_file(random.choice(batch.clip_image_embeds_unconditional)) for i in
range(noisy_latents.shape[0])
]
unconditional_clip_embeds = self.adapter.parse_clip_image_embeds_from_cache(
embeds,
@@ -1424,7 +1426,6 @@ class SDTrainer(BaseSDTrainProcess):
if prior_pred is not None:
prior_pred = prior_pred.detach()
# do the custom adapter after the prior prediction
if self.adapter and isinstance(self.adapter, CustomAdapter) and has_clip_image:
quad_count = random.randint(1, 4)
@@ -1450,10 +1451,12 @@ class SDTrainer(BaseSDTrainProcess):
self.adapter.add_extra_values(batch.extra_values.detach())
if self.train_config.do_cfg:
self.adapter.add_extra_values(torch.zeros_like(batch.extra_values.detach()), is_unconditional=True)
self.adapter.add_extra_values(torch.zeros_like(batch.extra_values.detach()),
is_unconditional=True)
if has_adapter_img:
if (self.adapter and isinstance(self.adapter, ControlNetModel)) or (self.assistant_adapter and isinstance(self.assistant_adapter, ControlNetModel)):
if (self.adapter and isinstance(self.adapter, ControlNetModel)) or (
self.assistant_adapter and isinstance(self.assistant_adapter, ControlNetModel)):
if self.train_config.do_cfg:
raise ValueError("ControlNetModel is not supported with CFG")
with torch.set_grad_enabled(self.adapter is not None):
@@ -1478,7 +1481,6 @@ class SDTrainer(BaseSDTrainProcess):
pred_kwargs['down_block_additional_residuals'] = down_block_res_samples
pred_kwargs['mid_block_additional_residual'] = mid_block_res_sample
self.before_unet_predict()
# do a prior pred if we have an unconditional image, we will swap out the guidance later
if batch.unconditional_latents is not None or self.do_guided_loss:
@@ -1526,7 +1528,6 @@ class SDTrainer(BaseSDTrainProcess):
print("loss is nan")
loss = torch.zeros_like(loss).requires_grad_(True)
with self.timer('backward'):
# todo we have multiplier separated. works for now as res are not in same batch, but need to change
loss = loss * loss_multiplier.mean()
@@ -1543,8 +1544,27 @@ class SDTrainer(BaseSDTrainProcess):
loss.backward()
else:
self.scaler.scale(loss).backward()
return loss.detach()
# flush()
def hook_train_loop(self, batch: Union[DataLoaderBatchDTO, List[DataLoaderBatchDTO]]):
if isinstance(batch, list):
batch_list = batch
else:
batch_list = [batch]
total_loss = None
for batch in batch_list:
self.optimizer.zero_grad(set_to_none=True)
loss = self.train_single_accumulation(batch)
if total_loss is None:
total_loss = loss
else:
total_loss += loss
if len(batch_list) > 1 and self.model_config.low_vram:
torch.cuda.empty_cache()
if not self.is_grad_accumulation_step:
# fix this for multi params
if self.train_config.optimizer != 'adafactor':