Bugfixes for full finetuning at bf16

This commit is contained in:
Jaret Burkett
2024-08-22 05:15:33 -06:00
parent 6a754b2710
commit e07a98a50c

View File

@@ -60,14 +60,15 @@ class SDTrainer(BaseSDTrainProcess):
self.scaler = torch.cuda.amp.GradScaler()
# patch the scaler to allow fp16 training
org_unscale_grads = self.scaler._unscale_grads_
def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16):
return org_unscale_grads(optimizer, inv_scale, found_inf, True)
self.scaler._unscale_grads_ = _unscale_grads_replacer
self.is_bfloat = self.train_config.dtype == "bfloat16" or self.train_config.dtype == "bf16"
if self.train_config.dtype in ["fp16", "float16"]:
# patch the scaler to allow fp16 training
org_unscale_grads = self.scaler._unscale_grads_
def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16):
return org_unscale_grads(optimizer, inv_scale, found_inf, True)
self.scaler._unscale_grads_ = _unscale_grads_replacer
def before_model_load(self):
pass
@@ -1518,13 +1519,17 @@ class SDTrainer(BaseSDTrainProcess):
# if self.is_bfloat:
# loss.backward()
# else:
self.scaler.scale(loss).backward()
if self.is_fine_tuning:
loss.backward()
else:
self.scaler.scale(loss).backward()
# flush()
if not self.is_grad_accumulation_step:
# fix this for multi params
if self.train_config.optimizer != 'adafactor':
self.scaler.unscale_(self.optimizer)
if not self.is_fine_tuning:
self.scaler.unscale_(self.optimizer)
if isinstance(self.params[0], dict):
for i in range(len(self.params)):
torch.nn.utils.clip_grad_norm_(self.params[i]['params'], self.train_config.max_grad_norm)
@@ -1533,8 +1538,12 @@ class SDTrainer(BaseSDTrainProcess):
# only step if we are not accumulating
with self.timer('optimizer_step'):
# self.optimizer.step()
self.scaler.step(self.optimizer)
self.scaler.update()
if self.is_fine_tuning:
self.optimizer.step()
else:
self.scaler.step(self.optimizer)
self.scaler.update()
self.optimizer.zero_grad(set_to_none=True)
if self.ema is not None:
with self.timer('ema_update'):