mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-01-26 16:39:47 +00:00
fixed bug that kept learning rates the same
@@ -1363,7 +1363,13 @@ class SDTrainer(BaseSDTrainProcess):
         # flush()

         if not self.is_grad_accumulation_step:
-            torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm)
+            # torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm)
+            # fix this for multi params
+            if isinstance(self.params[0], dict):
+                for i in range(len(self.params)):
+                    torch.nn.utils.clip_grad_norm_(self.params[i]['params'], self.train_config.max_grad_norm)
+            else:
+                torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm)
             # only step if we are not accumulating
             with self.timer('optimizer_step'):
                 # apply gradients
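The hunk above makes gradient clipping aware of optimizer param groups: `self.params` may now be a list of dicts rather than a flat tensor list. A minimal standalone sketch of the same branch logic, assuming `params` is either a flat list of tensors or a list of param-group dicts (the function name here is illustrative, not the trainer's):

```python
import torch

# Sketch of the clipping branch above, pulled out of the trainer.
# `params` mirrors what torch optimizers accept: a flat list of tensors,
# or a list of param-group dicts, each carrying its own 'params' list.
def clip_params(params, max_grad_norm):
    if isinstance(params[0], dict):
        # Param groups: clip each group's tensors separately.
        for group in params:
            torch.nn.utils.clip_grad_norm_(group['params'], max_grad_norm)
    else:
        # Flat tensor list: clip over the whole set at once.
        torch.nn.utils.clip_grad_norm_(params, max_grad_norm)
```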
@@ -1327,13 +1327,14 @@ class BaseSDTrainProcess(BaseTrainProcess):
         flush()
         ### HOOK ###
         params = self.hook_add_extra_train_params(params)
-        self.params = []
+        self.params = params
+        # self.params = []

-        for param in params:
-            if isinstance(param, dict):
-                self.params += param['params']
-            else:
-                self.params.append(param)
+        # for param in params:
+        #     if isinstance(param, dict):
+        #         self.params += param['params']
+        #     else:
+        #         self.params.append(param)

         if self.train_config.start_step is not None:
             self.step_num = self.train_config.start_step
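This hunk is the learning-rate fix named in the commit message: flattening param-group dicts into one tensor list discards each group's `lr`, so the optimizer falls back to a single default rate for everything. A small demonstration with stock PyTorch and made-up group contents:

```python
import torch

# Two hypothetical param groups with different learning rates.
unet_params = [torch.nn.Parameter(torch.zeros(4))]
text_params = [torch.nn.Parameter(torch.zeros(4))]
param_groups = [
    {'params': unet_params, 'lr': 1e-4},
    {'params': text_params, 'lr': 5e-5},
]

# Old (buggy) path: flattening drops the per-group 'lr' values.
flat = []
for group in param_groups:
    flat += group['params']
buggy = torch.optim.AdamW(flat, lr=1e-4)
print([g['lr'] for g in buggy.param_groups])  # [0.0001] -- one rate for all

# Fixed path: pass the dicts through so each group keeps its own rate.
fixed = torch.optim.AdamW(param_groups, lr=1e-4)
print([g['lr'] for g in fixed.param_groups])  # [0.0001, 5e-05]
```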
@@ -1431,6 +1432,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
         flush()
         # self.step_num = 0

+        print(f"Compiling Model")
+        torch.compile(self.sd.unet, dynamic=True)
+
         ###################################################################
         # TRAIN LOOP
         ###################################################################
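One note on the added compile call: `torch.compile` returns an optimized wrapper rather than modifying the module in place, so the usual pattern assigns the result back. A minimal sketch with a stand-in module (the commit compiles `self.sd.unet`):

```python
import torch

# Stand-in module for illustration only.
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.GELU())

# torch.compile returns the compiled module; assign it back to use it.
# dynamic=True asks the compiler to tolerate varying input shapes.
model = torch.compile(model, dynamic=True)

out = model(torch.randn(2, 8))
```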
@@ -1438,6 +1442,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
         start_step_num = self.step_num
+        did_first_flush = False
         for step in range(start_step_num, self.train_config.steps):
             self.timer.start('train_loop')
             if self.train_config.do_random_cfg:
                 self.train_config.do_cfg = True
                 self.train_config.cfg_scale = value_map(random.random(), 0, 1, 1.0, self.train_config.max_cfg_scale)
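`value_map` is a toolkit helper; as called above it remaps a uniform draw in [0, 1) to a CFG scale in [1.0, max_cfg_scale). A hedged sketch of what such a linear remap conventionally looks like (the real helper's internals may differ):

```python
import random

def value_map(x, min_in, max_in, min_out, max_out):
    # Linearly remap x from [min_in, max_in] to [min_out, max_out].
    return min_out + (x - min_in) * (max_out - min_out) / (max_in - min_in)

max_cfg_scale = 7.0  # placeholder; the trainer reads this from config
cfg_scale = value_map(random.random(), 0, 1, 1.0, max_cfg_scale)
```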
@@ -1506,7 +1511,6 @@ class BaseSDTrainProcess(BaseTrainProcess):

             # flush()
             ### HOOK ###
-            self.timer.start('train_loop')
             loss_dict = self.hook_train_loop(batch)
             self.timer.stop('train_loop')
             if not did_first_flush:
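The `did_first_flush` flag introduced in the earlier hunk gates a one-time memory flush after the first training step. A sketch of the pattern, assuming `flush()` follows the common `gc.collect()` plus `torch.cuda.empty_cache()` idiom (the toolkit's own helper may differ):

```python
import gc
import torch

def flush():
    # Common idiom: drop Python garbage, then release cached CUDA blocks.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

did_first_flush = False
for step in range(3):  # toy loop standing in for the train loop
    # ... one training step runs here ...
    if not did_first_flush:
        # The first step allocates transient peak memory (activations,
        # optimizer state); flush once afterwards to release it.
        flush()
        did_first_flush = True
```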