mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-04-30 11:11:37 +00:00
8 bit training working on flux
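The commit is about running the optimizer in 8-bit while training Flux. As a rough illustration of what "8 bit training" means on the optimizer side, a minimal sketch follows; it assumes bitsandbytes and a CUDA GPU, and the direct construction is only illustrative (ai-toolkit actually selects the optimizer from its YAML config, e.g. optimizer: "adamw8bit").

import torch
import bitsandbytes as bnb  # assumed dependency; the 8-bit kernels need a CUDA GPU

model = torch.nn.Linear(64, 64).cuda()
# AdamW8bit keeps the exp_avg / exp_avg_sq state in 8-bit, roughly a 4x saving over
# fp32 optimizer state, which is the main memory win when training a large model
optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=1e-4, weight_decay=1e-2)

loss = model(torch.randn(2, 64, device='cuda')).pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad(set_to_none=True)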
@@ -1538,22 +1538,19 @@ class SDTrainer(BaseSDTrainProcess):
         # flush()
 
         if not self.is_grad_accumulation_step:
             # torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm)
-            # fix this for multi params
-            if isinstance(self.params[0], dict):
-                for i in range(len(self.params)):
-                    torch.nn.utils.clip_grad_norm_(self.params[i]['params'], self.train_config.max_grad_norm)
-            else:
-                torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm)
+            if self.train_config.optimizer != 'adafactor':
+                self.scaler.unscale_(self.optimizer)
+                if isinstance(self.params[0], dict):
+                    for i in range(len(self.params)):
+                        torch.nn.utils.clip_grad_norm_(self.params[i]['params'], self.train_config.max_grad_norm)
+                else:
+                    torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm)
             # only step if we are not accumulating
             with self.timer('optimizer_step'):
-                if self.is_bfloat:
-                    self.optimizer.step()
-                else:
-                    # apply gradients
-                    self.optimizer.step()
-                    # self.scaler.update()
-                    # self.optimizer.step()
+                self.scaler.step(self.optimizer)
+                self.scaler.update()
             self.optimizer.zero_grad(set_to_none=True)
             if self.ema is not None:
                 with self.timer('ema_update'):
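For context, the clipping and stepping logic this hunk touches follows the standard PyTorch mixed-precision pattern: unscale gradients before clipping so the norm is measured on real gradient values, clip each param group separately when the params are a list of dicts, then step through the GradScaler, or step the optimizer directly when no scaler is in play (e.g. bf16). The sketch below is a minimal, self-contained illustration; the toy model, the use_scaler flag and the hyperparameters are stand-ins, not ai-toolkit's actual wiring.

import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = torch.nn.Linear(8, 8).to(device)
# param-group style, mirroring how self.params can be a list of dicts
params = [{'params': list(model.parameters()), 'lr': 1e-4}]
optimizer = torch.optim.AdamW(params)
scaler = torch.cuda.amp.GradScaler(enabled=(device == 'cuda'))
use_scaler = scaler.is_enabled()  # disabled for bf16 or CPU runs
max_grad_norm = 1.0

loss = model(torch.randn(4, 8, device=device)).pow(2).mean()
(scaler.scale(loss) if use_scaler else loss).backward()

if use_scaler:
    # gradients must be unscaled before clipping, otherwise the norm is off by the scale factor
    scaler.unscale_(optimizer)

# clip per param group when params is a list of dicts, else clip the flat list
if isinstance(params[0], dict):
    for group in params:
        torch.nn.utils.clip_grad_norm_(group['params'], max_grad_norm)
else:
    torch.nn.utils.clip_grad_norm_(params, max_grad_norm)

if use_scaler:
    scaler.step(optimizer)  # skips the step if inf/NaN gradients were detected
    scaler.update()
else:
    optimizer.step()
optimizer.zero_grad(set_to_none=True)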