diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index fd0fe800..da58b159 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -60,14 +60,15 @@ class SDTrainer(BaseSDTrainProcess):
         self.scaler = torch.cuda.amp.GradScaler()
-        # patch the scaler to allow fp16 training
-        org_unscale_grads = self.scaler._unscale_grads_
-        def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16):
-            return org_unscale_grads(optimizer, inv_scale, found_inf, True)
-        self.scaler._unscale_grads_ = _unscale_grads_replacer
-
         self.is_bfloat = self.train_config.dtype == "bfloat16" or self.train_config.dtype == "bf16"
+        if self.train_config.dtype in ["fp16", "float16"]:
+            # patch the scaler to allow fp16 training
+            org_unscale_grads = self.scaler._unscale_grads_
+            def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16):
+                return org_unscale_grads(optimizer, inv_scale, found_inf, True)
+            self.scaler._unscale_grads_ = _unscale_grads_replacer
+
 
     def before_model_load(self):
         pass
@@ -1518,13 +1519,17 @@ class SDTrainer(BaseSDTrainProcess):
                 # if self.is_bfloat:
                 #     loss.backward()
                 # else:
-                self.scaler.scale(loss).backward()
+                if self.is_fine_tuning:
+                    loss.backward()
+                else:
+                    self.scaler.scale(loss).backward()
             # flush()

         if not self.is_grad_accumulation_step:
             # fix this for multi params
             if self.train_config.optimizer != 'adafactor':
-                self.scaler.unscale_(self.optimizer)
+                if not self.is_fine_tuning:
+                    self.scaler.unscale_(self.optimizer)
             if isinstance(self.params[0], dict):
                 for i in range(len(self.params)):
                     torch.nn.utils.clip_grad_norm_(self.params[i]['params'], self.train_config.max_grad_norm)
@@ -1533,8 +1538,12 @@ class SDTrainer(BaseSDTrainProcess):
             # only step if we are not accumulating
             with self.timer('optimizer_step'):
                 # self.optimizer.step()
-                self.scaler.step(self.optimizer)
-                self.scaler.update()
+                if self.is_fine_tuning:
+                    self.optimizer.step()
+                else:
+                    self.scaler.step(self.optimizer)
+                    self.scaler.update()
+
                 self.optimizer.zero_grad(set_to_none=True)
         if self.ema is not None:
             with self.timer('ema_update'):
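
For context on the first hunk: `GradScaler.unscale_()` refuses to unscale gradients that are stored in float16, which is why the monkey-patch forcing `allow_fp16=True` exists, and why it is now only applied when `train_config.dtype` is fp16 (bf16 runs never hit this path). Below is a minimal repro sketch of the failure mode the patch works around, assuming a CUDA device is available; `model` and `optimizer` are stand-in names, not identifiers from this repo.

```python
import torch

# fp16 model weights produce fp16 gradients, which the stock
# GradScaler refuses to unscale.
model = torch.nn.Linear(8, 8).cuda().half()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()

x = torch.randn(4, 8, device="cuda", dtype=torch.float16)
loss = model(x).sum()
scaler.scale(loss).backward()

try:
    scaler.unscale_(optimizer)  # raises without the _unscale_grads_ patch
except ValueError as e:
    print(e)  # "Attempting to unscale FP16 gradients."
```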
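The remaining two hunks route fine-tuning runs around the `GradScaler` entirely: plain `backward()`/`optimizer.step()` when `is_fine_tuning` is set, the scaled path otherwise. Here is a condensed sketch of the resulting step logic, with the diff's `self.*` attributes flattened into function arguments and the adafactor and multi-param-group special cases omitted for brevity:

```python
import torch

def train_step(loss, optimizer, scaler, params,
               is_fine_tuning, is_grad_accumulation_step, max_grad_norm):
    # Fine-tuning bypasses loss scaling; the fp16 path scales the loss
    # before backward so small gradients do not underflow.
    if is_fine_tuning:
        loss.backward()
    else:
        scaler.scale(loss).backward()

    if not is_grad_accumulation_step:
        if not is_fine_tuning:
            # Gradients must be unscaled before clipping; otherwise the
            # norm would be computed on scaled values.
            scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(params, max_grad_norm)

        if is_fine_tuning:
            optimizer.step()
        else:
            scaler.step(optimizer)  # skips the step if infs/NaNs were found
            scaler.update()
        optimizer.zero_grad(set_to_none=True)
```

Note that `scaler.step()` silently skips the optimizer step when the unscale pass found non-finite gradients, which is why `scaler.update()` must follow it to adjust the scale factor for the next iteration.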