mirror of https://github.com/ostris/ai-toolkit.git
Bugfixes for full finetuning at bf16
@@ -60,14 +60,15 @@ class SDTrainer(BaseSDTrainProcess):
         self.scaler = torch.cuda.amp.GradScaler()
 
-        # patch the scaler to allow fp16 training
-        org_unscale_grads = self.scaler._unscale_grads_
-        def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16):
-            return org_unscale_grads(optimizer, inv_scale, found_inf, True)
-        self.scaler._unscale_grads_ = _unscale_grads_replacer
 
         self.is_bfloat = self.train_config.dtype == "bfloat16" or self.train_config.dtype == "bf16"
 
+        if self.train_config.dtype in ["fp16", "float16"]:
+            # patch the scaler to allow fp16 training
+            org_unscale_grads = self.scaler._unscale_grads_
+            def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16):
+                return org_unscale_grads(optimizer, inv_scale, found_inf, True)
+            self.scaler._unscale_grads_ = _unscale_grads_replacer
 
 
     def before_model_load(self):
         pass
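For reference, a standalone sketch of what the new scaler setup does, assuming a plain dtype string in place of the toolkit's train_config. The patch mirrors the diff above and relies on GradScaler's private _unscale_grads_ method, so it may break across PyTorch versions.

import torch

# dtype string stands in for train_config.dtype ("bf16"/"bfloat16" or "fp16"/"float16")
dtype = "bf16"

scaler = torch.cuda.amp.GradScaler()

# bf16 needs no loss scaling; remember the mode so later steps can bypass the scaler
is_bfloat = dtype in ("bfloat16", "bf16")

if dtype in ("fp16", "float16"):
    # GradScaler refuses to unscale fp16 gradients by default; wrap the private
    # unscale hook so it always passes allow_fp16=True
    org_unscale_grads = scaler._unscale_grads_

    def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16):
        return org_unscale_grads(optimizer, inv_scale, found_inf, True)

    scaler._unscale_grads_ = _unscale_grads_replacer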
@@ -1518,13 +1519,17 @@ class SDTrainer(BaseSDTrainProcess):
         # if self.is_bfloat:
         #     loss.backward()
         # else:
-        self.scaler.scale(loss).backward()
+        if self.is_fine_tuning:
+            loss.backward()
+        else:
+            self.scaler.scale(loss).backward()
         # flush()
 
         if not self.is_grad_accumulation_step:
             # fix this for multi params
             if self.train_config.optimizer != 'adafactor':
-                self.scaler.unscale_(self.optimizer)
+                if not self.is_fine_tuning:
+                    self.scaler.unscale_(self.optimizer)
                 if isinstance(self.params[0], dict):
                     for i in range(len(self.params)):
                         torch.nn.utils.clip_grad_norm_(self.params[i]['params'], self.train_config.max_grad_norm)
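A minimal sketch of the backward/clip branch above, using a toy model and a hypothetical is_fine_tuning flag in place of the trainer's state; gradient accumulation, param groups, and the adafactor special case are omitted.

import torch

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scaler = torch.cuda.amp.GradScaler()
is_fine_tuning = True  # full finetune at bf16: bypass the fp16 loss scaler
max_grad_norm = 1.0

loss = model(torch.randn(2, 4)).mean()

if is_fine_tuning:
    # bf16 full finetune: plain backward, no loss scaling
    loss.backward()
else:
    # fp16 path: scale the loss so small gradients do not underflow in fp16
    scaler.scale(loss).backward()

if not is_fine_tuning:
    # gradients must be unscaled before clipping when the scaler was used
    scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)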
@@ -1533,8 +1538,12 @@ class SDTrainer(BaseSDTrainProcess):
             # only step if we are not accumulating
             with self.timer('optimizer_step'):
                 # self.optimizer.step()
-                self.scaler.step(self.optimizer)
-                self.scaler.update()
+                if self.is_fine_tuning:
+                    self.optimizer.step()
+                else:
+                    self.scaler.step(self.optimizer)
+                    self.scaler.update()
+
                 self.optimizer.zero_grad(set_to_none=True)
         if self.ema is not None:
             with self.timer('ema_update'):
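And the matching step branch, again as a self-contained sketch with a toy model and a hypothetical is_fine_tuning flag; the timer and EMA hooks from the trainer are left out.

import torch

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scaler = torch.cuda.amp.GradScaler()
is_fine_tuning = True  # full finetune at bf16

loss = model(torch.randn(2, 4)).mean()
loss.backward()  # bf16 path shown; the fp16 path would go through scaler.scale()

if is_fine_tuning:
    # the scaler never saw this loss, so step the optimizer directly
    optimizer.step()
else:
    # fp16 path: scaler.step() skips the update if inf/NaN gradients were found,
    # and update() re-tunes the loss scale for the next iteration
    scaler.step(optimizer)
    scaler.update()

optimizer.zero_grad(set_to_none=True)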