Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-04-26 17:29:27 +00:00
WIP on SAFE encoder. Work on fp16 training improvements. Various other tweaks and improvements
@@ -31,6 +31,7 @@ from jobs.process import BaseSDTrainProcess
+from torchvision import transforms



 def flush():
     torch.cuda.empty_cache()
     gc.collect()
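
For context, flush() depends on two imports that fall outside the hunk. A minimal standalone version, assuming only the standard gc module and torch, looks like this:

import gc

import torch


def flush():
    # return cached, unused CUDA blocks to the driver so reported VRAM drops
    torch.cuda.empty_cache()
    # drop Python-side garbage that may still be pinning tensor memory
    gc.collect()

Note that empty_cache() only releases memory that is no longer referenced, so the gc.collect() call helps when stale Python references are keeping tensors alive.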
@@ -55,6 +56,17 @@ class SDTrainer(BaseSDTrainProcess):
         self.negative_prompt_pool: Union[List[str], None] = None
         self.batch_negative_prompt: Union[List[str], None] = None

+        self.scaler = torch.cuda.amp.GradScaler()
+
+        # patch the scaler to allow fp16 training
+        org_unscale_grads = self.scaler._unscale_grads_
+
+        def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16):
+            return org_unscale_grads(optimizer, inv_scale, found_inf, True)
+
+        self.scaler._unscale_grads_ = _unscale_grads_replacer
+
+        self.is_bfloat = self.train_config.dtype == "bfloat16" or self.train_config.dtype == "bf16"

     def before_model_load(self):
         pass
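
The patch above works around a guard in PyTorch's GradScaler: its private _unscale_grads_ method raises "ValueError: Attempting to unscale FP16 gradients." when allow_fp16 is False, which blocks training models whose parameters are themselves fp16. Wrapping it to force allow_fp16=True lifts that restriction. A minimal, self-contained repro of what this enables (the toy parameter and optimizer are illustrative, not from the trainer):

import torch

# toy fp16 parameter; without the patch, scaler.step() below raises
# "ValueError: Attempting to unscale FP16 gradients."
param = torch.nn.Parameter(torch.randn(8, device="cuda", dtype=torch.float16))
optimizer = torch.optim.SGD([param], lr=1e-3)

scaler = torch.cuda.amp.GradScaler()
org_unscale_grads = scaler._unscale_grads_


def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16):
    # ignore the caller's allow_fp16 and always permit fp16 params
    return org_unscale_grads(optimizer, inv_scale, found_inf, True)


scaler._unscale_grads_ = _unscale_grads_replacer

loss = (param.float() ** 2).sum()
scaler.scale(loss).backward()
scaler.step(optimizer)  # unscales the fp16 grads in place, then steps
scaler.update()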
@@ -1401,6 +1413,7 @@ class SDTrainer(BaseSDTrainProcess):
             print("loss is nan")
             loss = torch.zeros_like(loss).requires_grad_(True)

+
         with self.timer('backward'):
             # todo we have multiplier separated. works for now as res are not in same batch, but need to change
             loss = loss * loss_multiplier.mean()
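
The NaN guard works because torch.zeros_like(loss).requires_grad_(True) creates a fresh leaf tensor with no graph history: calling backward() on it succeeds but produces no gradients for the model, so the bad step is effectively skipped instead of poisoning the weights. A tiny standalone demonstration with a toy loss:

import torch

loss = torch.tensor(float("nan"), requires_grad=True) * 2.0  # toy NaN loss
if torch.isnan(loss):
    print("loss is nan")
    # swap in a zero-valued leaf with no history; backward() becomes a
    # harmless no-op as far as model parameters are concerned
    loss = torch.zeros_like(loss).requires_grad_(True)
loss.backward()  # runs without error and touches no model gradients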
@@ -1410,7 +1423,10 @@ class SDTrainer(BaseSDTrainProcess):
             # 0.0 for the backward pass and the gradients will be 0.0
             # I spent weeks on fighting this. DON'T DO IT
             # with fsdp_overlap_step_with_backward():
-            loss.backward()
+            if self.is_bfloat:
+                loss.backward()
+            else:
+                self.scaler.scale(loss).backward()
             # flush()

         if not self.is_grad_accumulation_step:
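
The split above reflects the two dtypes' numerics: bfloat16 keeps float32's exponent range, so gradients rarely underflow and a plain backward() is enough, while float16's narrow range pushes small gradients to zero unless the loss is first multiplied by a large scale factor. A sketch of the branch as a standalone helper (is_bfloat, scaler, and loss stand in for the trainer's attributes):

def backward_step(loss, scaler, is_bfloat):
    if is_bfloat:
        # bf16 shares fp32's exponent range; no loss scaling needed
        loss.backward()
    else:
        # fp16 underflows small gradients, so scale the loss up first;
        # the scaler divides the gradients back down at step time
        scaler.scale(loss).backward()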
@@ -1423,8 +1439,13 @@ class SDTrainer(BaseSDTrainProcess):
                 torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm)
             # only step if we are not accumulating
             with self.timer('optimizer_step'):
-                # apply gradients
-                self.optimizer.step()
+                if self.is_bfloat:
+                    self.optimizer.step()
+                else:
+                    # apply gradients
+                    self.scaler.step(self.optimizer)
+                    self.scaler.update()
+                    # self.optimizer.step()
                 self.optimizer.zero_grad(set_to_none=True)
         else:
             # gradient accumulation. Just a place for breakpoint
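
One caveat worth noting on the fp16 path: in the hunk above, clip_grad_norm_ runs on still-scaled gradients, so the effective clip threshold varies with the current loss scale. The pattern documented for torch.cuda.amp is to call scaler.unscale_(optimizer) first; scaler.step() then detects that the gradients are already unscaled and will not unscale twice. A sketch of that variant (a hypothetical helper, not the commit's code):

import torch


def optimizer_step(optimizer, scaler, params, max_grad_norm, is_bfloat):
    if is_bfloat:
        if max_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(params, max_grad_norm)
        optimizer.step()
    else:
        # unscale in place so clipping sees true gradient magnitudes
        scaler.unscale_(optimizer)
        if max_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(params, max_grad_norm)
        # skips the step if any grad is inf/NaN, then update() adjusts
        # the scale factor for the next iteration
        scaler.step(optimizer)
        scaler.update()
    optimizer.zero_grad(set_to_none=True)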