8-bit training working on Flux

Author: Jaret Burkett
Date: 2024-08-06 11:53:27 -06:00
Parent: 272c8608c2
Commit: c2424087d6
7 changed files with 82 additions and 31 deletions


@@ -1538,22 +1538,19 @@ class SDTrainer(BaseSDTrainProcess):
                 # flush()
                 if not self.is_grad_accumulation_step:
                     # torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm)
                     # fix this for multi params
-                    if isinstance(self.params[0], dict):
-                        for i in range(len(self.params)):
-                            torch.nn.utils.clip_grad_norm_(self.params[i]['params'], self.train_config.max_grad_norm)
-                    else:
-                        torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm)
+                    if self.train_config.optimizer != 'adafactor':
+                        self.scaler.unscale_(self.optimizer)
+                        if isinstance(self.params[0], dict):
+                            for i in range(len(self.params)):
+                                torch.nn.utils.clip_grad_norm_(self.params[i]['params'], self.train_config.max_grad_norm)
+                        else:
+                            torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm)
                     # only step if we are not accumulating
                     with self.timer('optimizer_step'):
-                        if self.is_bfloat:
-                            self.optimizer.step()
-                        else:
-                            # apply gradients
-                            self.optimizer.step()
-                            # self.scaler.update()
                         # self.optimizer.step()
+                        self.scaler.step(self.optimizer)
                         self.scaler.update()
                     self.optimizer.zero_grad(set_to_none=True)
                 if self.ema is not None:
                     with self.timer('ema_update'):
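
For context, the change above replaces a bare optimizer.step() with the standard PyTorch GradScaler flow, which is what lets scaled fp16 gradients drive an 8-bit optimizer safely: gradients are unscaled before clipping so the norm is measured in true units, and scaler.step() skips the update whenever a gradient overflowed. The adafactor carve-out presumably exists because that optimizer manages its own update magnitudes and never interacts with the scaler. Below is a minimal, self-contained sketch of the same pattern, not code from this repo; model, loader, clip_params, and train_step are illustrative placeholders.

import torch

def clip_params(params, max_norm):
    # Mirrors the diff: `params` may be a flat list of tensors or a list
    # of optimizer param-group dicts like {'params': [...], 'lr': ...}.
    if isinstance(params[0], dict):
        for group in params:
            torch.nn.utils.clip_grad_norm_(group['params'], max_norm)
    else:
        torch.nn.utils.clip_grad_norm_(params, max_norm)

def train_step(model, optimizer, scaler, batch, params, max_grad_norm=1.0):
    # Forward and backward under autocast; the loss is scaled up so small
    # fp16 gradients do not flush to zero.
    with torch.autocast(device_type='cuda', dtype=torch.float16):
        loss = model(batch)
    scaler.scale(loss).backward()

    # Unscale *before* clipping (the new scaler.unscale_ call in the diff)
    # so max_grad_norm applies to real gradient magnitudes.
    scaler.unscale_(optimizer)
    clip_params(params, max_grad_norm)

    # scaler.step() silently skips the optimizer step if any gradient
    # overflowed; update() then retunes the scale factor for the next step.
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad(set_to_none=True)
    return loss.detach()

The "8 bit" in the commit title most likely refers to a bitsandbytes-style 8-bit optimizer (e.g. bnb.optim.AdamW8bit), which stores its state in 8 bits but consumes ordinary fp16/bf16 gradients, so it drops into this flow unchanged; the scaler itself would be constructed once at setup, e.g. scaler = torch.cuda.amp.GradScaler().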