Mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-04-30 19:21:39 +00:00)
Bugfixes for full finetuning at bf16
This commit is contained in:
@@ -60,14 +60,15 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
|
|
||||||
self.scaler = torch.cuda.amp.GradScaler()
|
self.scaler = torch.cuda.amp.GradScaler()
|
||||||
|
|
||||||
# patch the scaler to allow fp16 training
|
|
||||||
org_unscale_grads = self.scaler._unscale_grads_
|
|
||||||
def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16):
|
|
||||||
return org_unscale_grads(optimizer, inv_scale, found_inf, True)
|
|
||||||
self.scaler._unscale_grads_ = _unscale_grads_replacer
|
|
||||||
|
|
||||||
self.is_bfloat = self.train_config.dtype == "bfloat16" or self.train_config.dtype == "bf16"
|
self.is_bfloat = self.train_config.dtype == "bfloat16" or self.train_config.dtype == "bf16"
|
||||||
|
|
||||||
|
if self.train_config.dtype in ["fp16", "float16"]:
|
||||||
|
# patch the scaler to allow fp16 training
|
||||||
|
org_unscale_grads = self.scaler._unscale_grads_
|
||||||
|
def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16):
|
||||||
|
return org_unscale_grads(optimizer, inv_scale, found_inf, True)
|
||||||
|
self.scaler._unscale_grads_ = _unscale_grads_replacer
|
||||||
|
|
||||||
|
|
||||||
def before_model_load(self):
|
def before_model_load(self):
|
||||||
pass
|
pass
|
||||||
@@ -1518,13 +1519,17 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
# if self.is_bfloat:
|
# if self.is_bfloat:
|
||||||
# loss.backward()
|
# loss.backward()
|
||||||
# else:
|
# else:
|
||||||
self.scaler.scale(loss).backward()
|
if self.is_fine_tuning:
|
||||||
|
loss.backward()
|
||||||
|
else:
|
||||||
|
self.scaler.scale(loss).backward()
|
||||||
# flush()
|
# flush()
|
||||||
|
|
||||||
if not self.is_grad_accumulation_step:
|
if not self.is_grad_accumulation_step:
|
||||||
# fix this for multi params
|
# fix this for multi params
|
||||||
if self.train_config.optimizer != 'adafactor':
|
if self.train_config.optimizer != 'adafactor':
|
||||||
self.scaler.unscale_(self.optimizer)
|
if not self.is_fine_tuning:
|
||||||
|
self.scaler.unscale_(self.optimizer)
|
||||||
if isinstance(self.params[0], dict):
|
if isinstance(self.params[0], dict):
|
||||||
for i in range(len(self.params)):
|
for i in range(len(self.params)):
|
||||||
torch.nn.utils.clip_grad_norm_(self.params[i]['params'], self.train_config.max_grad_norm)
|
torch.nn.utils.clip_grad_norm_(self.params[i]['params'], self.train_config.max_grad_norm)
|
||||||
@@ -1533,8 +1538,12 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
# only step if we are not accumulating
|
# only step if we are not accumulating
|
||||||
with self.timer('optimizer_step'):
|
with self.timer('optimizer_step'):
|
||||||
# self.optimizer.step()
|
# self.optimizer.step()
|
||||||
self.scaler.step(self.optimizer)
|
if self.is_fine_tuning:
|
||||||
self.scaler.update()
|
self.optimizer.step()
|
||||||
|
else:
|
||||||
|
self.scaler.step(self.optimizer)
|
||||||
|
self.scaler.update()
|
||||||
|
|
||||||
self.optimizer.zero_grad(set_to_none=True)
|
self.optimizer.zero_grad(set_to_none=True)
|
||||||
if self.ema is not None:
|
if self.ema is not None:
|
||||||
with self.timer('ema_update'):
|
with self.timer('ema_update'):
|
||||||
|
|||||||
Reference in New Issue
Block a user