fixed bug the kept learning rates the same

This commit is contained in:
Jaret Burkett
2024-03-06 09:23:32 -07:00
parent b01e8d889a
commit f1cb87fe9e
2 changed files with 18 additions and 8 deletions

View File

@@ -1327,13 +1327,14 @@ class BaseSDTrainProcess(BaseTrainProcess):
flush()
### HOOK ###
params = self.hook_add_extra_train_params(params)
self.params = []
self.params = params
# self.params = []
for param in params:
if isinstance(param, dict):
self.params += param['params']
else:
self.params.append(param)
# for param in params:
# if isinstance(param, dict):
# self.params += param['params']
# else:
# self.params.append(param)
if self.train_config.start_step is not None:
self.step_num = self.train_config.start_step
@@ -1431,6 +1432,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
flush()
# self.step_num = 0
print(f"Compiling Model")
torch.compile(self.sd.unet, dynamic=True)
###################################################################
# TRAIN LOOP
###################################################################
@@ -1438,6 +1442,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
start_step_num = self.step_num
did_first_flush = False
for step in range(start_step_num, self.train_config.steps):
self.timer.start('train_loop')
if self.train_config.do_random_cfg:
self.train_config.do_cfg = True
self.train_config.cfg_scale = value_map(random.random(), 0, 1, 1.0, self.train_config.max_cfg_scale)
@@ -1506,7 +1511,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
# flush()
### HOOK ###
self.timer.start('train_loop')
loss_dict = self.hook_train_loop(batch)
self.timer.stop('train_loop')
if not did_first_flush: