From f1cb87fe9e0e2971f0edf4694c70298a9ed5eecd Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Wed, 6 Mar 2024 09:23:32 -0700
Subject: [PATCH] fixed bug that kept learning rates the same

---
 extensions_built_in/sd_trainer/SDTrainer.py |  8 +++++++-
 jobs/process/BaseSDTrainProcess.py          | 18 +++++++++++-------
 2 files changed, 18 insertions(+), 8 deletions(-)

diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index 9c397b4c..65424bc9 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -1363,7 +1363,13 @@ class SDTrainer(BaseSDTrainProcess):
             # flush()
 
         if not self.is_grad_accumulation_step:
-            torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm)
+            # torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm)
+            # fix this for multi params
+            if isinstance(self.params[0], dict):
+                for i in range(len(self.params)):
+                    torch.nn.utils.clip_grad_norm_(self.params[i]['params'], self.train_config.max_grad_norm)
+            else:
+                torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm)
             # only step if we are not accumulating
             with self.timer('optimizer_step'):
                 # apply gradients
diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index 472fada8..304d7bac 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -1327,13 +1327,14 @@ class BaseSDTrainProcess(BaseTrainProcess):
         flush()
         ### HOOK ###
         params = self.hook_add_extra_train_params(params)
-        self.params = []
+        self.params = params
+        # self.params = []
 
-        for param in params:
-            if isinstance(param, dict):
-                self.params += param['params']
-            else:
-                self.params.append(param)
+        # for param in params:
+        #     if isinstance(param, dict):
+        #         self.params += param['params']
+        #     else:
+        #         self.params.append(param)
 
         if self.train_config.start_step is not None:
             self.step_num = self.train_config.start_step
@@ -1431,6 +1432,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
             flush()
         # self.step_num = 0
 
+        print(f"Compiling Model")
+        torch.compile(self.sd.unet, dynamic=True)
+
         ###################################################################
         # TRAIN LOOP
         ###################################################################
@@ -1438,6 +1442,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
         start_step_num = self.step_num
         did_first_flush = False
         for step in range(start_step_num, self.train_config.steps):
+            self.timer.start('train_loop')
            if self.train_config.do_random_cfg:
                 self.train_config.do_cfg = True
                 self.train_config.cfg_scale = value_map(random.random(), 0, 1, 1.0, self.train_config.max_cfg_scale)
@@ -1506,7 +1511,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
             # flush()
 
             ### HOOK ###
-            self.timer.start('train_loop')
             loss_dict = self.hook_train_loop(batch)
             self.timer.stop('train_loop')
             if not did_first_flush:
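
For context, a minimal standalone sketch (not part of the patch) of the clipping pattern the SDTrainer.py hunk adopts: once self.params keeps the optimizer's param-group dicts intact, torch.nn.utils.clip_grad_norm_ has to be fed each group's tensor list rather than the dicts themselves, since it expects an iterable of Tensors. The clip_params helper and the toy param_groups below are hypothetical names for illustration only.

import torch

def clip_params(params, max_grad_norm=1.0):
    # clip_grad_norm_ expects an iterable of Tensors, so optimizer-style
    # param-group dicts must be unpacked and clipped group by group
    if len(params) > 0 and isinstance(params[0], dict):
        for group in params:
            torch.nn.utils.clip_grad_norm_(group['params'], max_grad_norm)
    else:
        torch.nn.utils.clip_grad_norm_(params, max_grad_norm)

# toy param groups with distinct learning rates; flattening them into one
# plain list (as the removed loop in BaseSDTrainProcess.py did) discards
# the per-group lr settings, which is the likely source of the identical
# learning rates this patch fixes
group_a = [torch.nn.Parameter(torch.randn(4, 4))]
group_b = [torch.nn.Parameter(torch.randn(4, 4))]
param_groups = [
    {'params': group_a, 'lr': 1e-4},
    {'params': group_b, 'lr': 5e-5},
]
optimizer = torch.optim.AdamW(param_groups)
loss = sum((p ** 2).sum() for group in param_groups for p in group['params'])
loss.backward()
clip_params(param_groups, max_grad_norm=1.0)
optimizer.step()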
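
One hedged side note on the torch.compile line added in BaseSDTrainProcess.py: as far as I know, torch.compile does not compile a module in place but returns an optimized wrapper, so the result normally needs to be assigned back for the compilation to take effect. A tiny sketch of the usual pattern, using a toy module rather than the trainer's unet:

import torch

model = torch.nn.Linear(4, 4)
# reassign: the return value of torch.compile is the optimized module
model = torch.compile(model, dynamic=True)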