mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 11:11:37 +00:00
Added proper grad accumulation
This commit is contained in:
@@ -910,7 +910,7 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
**kwargs
|
||||
)
|
||||
|
||||
def hook_train_loop(self, batch: 'DataLoaderBatchDTO'):
|
||||
def train_single_accumulation(self, batch: DataLoaderBatchDTO):
|
||||
self.timer.start('preprocess_batch')
|
||||
batch = self.preprocess_batch(batch)
|
||||
dtype = get_torch_dtype(self.train_config.dtype)
|
||||
@@ -1243,7 +1243,8 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
pred_kwargs = {}
|
||||
|
||||
if has_adapter_img:
|
||||
if (self.adapter and isinstance(self.adapter, T2IAdapter)) or (self.assistant_adapter and isinstance(self.assistant_adapter, T2IAdapter)):
|
||||
if (self.adapter and isinstance(self.adapter, T2IAdapter)) or (
|
||||
self.assistant_adapter and isinstance(self.assistant_adapter, T2IAdapter)):
|
||||
with torch.set_grad_enabled(self.adapter is not None):
|
||||
adapter = self.assistant_adapter if self.assistant_adapter is not None else self.adapter
|
||||
adapter_multiplier = get_adapter_multiplier()
|
||||
@@ -1283,7 +1284,8 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
|
||||
if self.train_config.do_cfg:
|
||||
embeds = [
|
||||
load_file(random.choice(batch.clip_image_embeds_unconditional)) for i in range(noisy_latents.shape[0])
|
||||
load_file(random.choice(batch.clip_image_embeds_unconditional)) for i in
|
||||
range(noisy_latents.shape[0])
|
||||
]
|
||||
unconditional_clip_embeds = self.adapter.parse_clip_image_embeds_from_cache(
|
||||
embeds,
|
||||
@@ -1424,7 +1426,6 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
if prior_pred is not None:
|
||||
prior_pred = prior_pred.detach()
|
||||
|
||||
|
||||
# do the custom adapter after the prior prediction
|
||||
if self.adapter and isinstance(self.adapter, CustomAdapter) and has_clip_image:
|
||||
quad_count = random.randint(1, 4)
|
||||
@@ -1450,10 +1451,12 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
self.adapter.add_extra_values(batch.extra_values.detach())
|
||||
|
||||
if self.train_config.do_cfg:
|
||||
self.adapter.add_extra_values(torch.zeros_like(batch.extra_values.detach()), is_unconditional=True)
|
||||
self.adapter.add_extra_values(torch.zeros_like(batch.extra_values.detach()),
|
||||
is_unconditional=True)
|
||||
|
||||
if has_adapter_img:
|
||||
if (self.adapter and isinstance(self.adapter, ControlNetModel)) or (self.assistant_adapter and isinstance(self.assistant_adapter, ControlNetModel)):
|
||||
if (self.adapter and isinstance(self.adapter, ControlNetModel)) or (
|
||||
self.assistant_adapter and isinstance(self.assistant_adapter, ControlNetModel)):
|
||||
if self.train_config.do_cfg:
|
||||
raise ValueError("ControlNetModel is not supported with CFG")
|
||||
with torch.set_grad_enabled(self.adapter is not None):
|
||||
@@ -1478,7 +1481,6 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
pred_kwargs['down_block_additional_residuals'] = down_block_res_samples
|
||||
pred_kwargs['mid_block_additional_residual'] = mid_block_res_sample
|
||||
|
||||
|
||||
self.before_unet_predict()
|
||||
# do a prior pred if we have an unconditional image, we will swap out the guidance later
|
||||
if batch.unconditional_latents is not None or self.do_guided_loss:
|
||||
@@ -1526,7 +1528,6 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
print("loss is nan")
|
||||
loss = torch.zeros_like(loss).requires_grad_(True)
|
||||
|
||||
|
||||
with self.timer('backward'):
|
||||
# todo we have multiplier separated. works for now as res are not in same batch, but need to change
|
||||
loss = loss * loss_multiplier.mean()
|
||||
@@ -1543,8 +1544,27 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
loss.backward()
|
||||
else:
|
||||
self.scaler.scale(loss).backward()
|
||||
|
||||
return loss.detach()
|
||||
# flush()
|
||||
|
||||
def hook_train_loop(self, batch: Union[DataLoaderBatchDTO, List[DataLoaderBatchDTO]]):
|
||||
if isinstance(batch, list):
|
||||
batch_list = batch
|
||||
else:
|
||||
batch_list = [batch]
|
||||
total_loss = None
|
||||
for batch in batch_list:
|
||||
self.optimizer.zero_grad(set_to_none=True)
|
||||
loss = self.train_single_accumulation(batch)
|
||||
if total_loss is None:
|
||||
total_loss = loss
|
||||
else:
|
||||
total_loss += loss
|
||||
if len(batch_list) > 1 and self.model_config.low_vram:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
if not self.is_grad_accumulation_step:
|
||||
# fix this for multi params
|
||||
if self.train_config.optimizer != 'adafactor':
|
||||
|
||||
Reference in New Issue
Block a user