WIP on SAFE encoder. Work on fp16 training improvements. Various other tweaks and improvements

This commit is contained in:
Jaret Burkett
2024-05-27 10:50:24 -06:00
parent 68b7e159bc
commit 833c833f28
9 changed files with 127 additions and 49 deletions

View File

@@ -31,6 +31,7 @@ from jobs.process import BaseSDTrainProcess
from torchvision import transforms
def flush():
torch.cuda.empty_cache()
gc.collect()
@@ -55,6 +56,17 @@ class SDTrainer(BaseSDTrainProcess):
self.negative_prompt_pool: Union[List[str], None] = None
self.batch_negative_prompt: Union[List[str], None] = None
self.scaler = torch.cuda.amp.GradScaler()
# patch the scaler to allow fp16 training
org_unscale_grads = self.scaler._unscale_grads_
def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16):
return org_unscale_grads(optimizer, inv_scale, found_inf, True)
self.scaler._unscale_grads_ = _unscale_grads_replacer
self.is_bfloat = self.train_config.dtype == "bfloat16" or self.train_config.dtype == "bf16"
def before_model_load(self):
pass
@@ -1401,6 +1413,7 @@ class SDTrainer(BaseSDTrainProcess):
print("loss is nan")
loss = torch.zeros_like(loss).requires_grad_(True)
with self.timer('backward'):
# todo we have multiplier seperated. works for now as res are not in same batch, but need to change
loss = loss * loss_multiplier.mean()
@@ -1410,7 +1423,10 @@ class SDTrainer(BaseSDTrainProcess):
# 0.0 for the backward pass and the gradients will be 0.0
# I spent weeks on fighting this. DON'T DO IT
# with fsdp_overlap_step_with_backward():
loss.backward()
if self.is_bfloat:
loss.backward()
else:
self.scaler.scale(loss).backward()
# flush()
if not self.is_grad_accumulation_step:
@@ -1423,8 +1439,13 @@ class SDTrainer(BaseSDTrainProcess):
torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm)
# only step if we are not accumulating
with self.timer('optimizer_step'):
# apply gradients
self.optimizer.step()
if self.is_bfloat:
self.optimizer.step()
else:
# apply gradients
self.scaler.step(self.optimizer)
self.scaler.update()
# self.optimizer.step()
self.optimizer.zero_grad(set_to_none=True)
else:
# gradient accumulation. Just a place for breakpoint