Allow special args for schedulers

Author: Jaret Burkett
Date: 2023-09-03 20:38:44 -06:00
Parent: 7cd6945082
Commit: 22ed539321
3 changed files with 25 additions and 11 deletions


@@ -553,13 +553,17 @@ class BaseSDTrainProcess(BaseTrainProcess):
    optimizer_params=self.train_config.optimizer_params)
self.optimizer = optimizer
lr_scheduler_params = self.train_config.lr_scheduler_params
# make sure it has the bare minimum
if 'max_iterations' not in lr_scheduler_params:
    lr_scheduler_params['total_iters'] = self.train_config.steps
lr_scheduler = get_lr_scheduler(
    self.train_config.lr_scheduler,
    optimizer,
    max_iterations=self.train_config.steps,
    lr_min=self.train_config.lr / 100,
    **lr_scheduler_params
)
self.lr_scheduler = lr_scheduler
### HOOK ###
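
Below is a minimal standalone sketch (not part of the commit) of how a scheduler params dict like the one above can be forwarded to a stock PyTorch scheduler. The optimizer, step count, and the choice of ConstantLR are illustrative assumptions; only the total_iters fallback mirrors the guard in the diff.

# Illustrative sketch only: assumes lr_scheduler_params is a plain dict taken
# from the training config, with extra args forwarded as keyword arguments.
import torch

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

steps = 1000                  # stand-in for train_config.steps
lr_scheduler_params = {}      # stand-in for train_config.lr_scheduler_params

# same default as the guard in the diff: fall back to the configured step count
if 'max_iterations' not in lr_scheduler_params:
    lr_scheduler_params['total_iters'] = steps

# scheduler-specific args pass straight through to the scheduler constructor
scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=0.5, **lr_scheduler_params)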
@@ -601,6 +605,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
# zero any gradients
optimizer.zero_grad()
self.lr_scheduler.step(self.step_num)
# self.step_num = 0
for step in range(self.step_num, self.train_config.steps):
    with torch.no_grad():
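
A small sketch of what stepping the scheduler with the current step index accomplishes when a run resumes partway through: it fast-forwards the LR schedule so the learning rate matches the resume point. The cosine scheduler, step count, and resume step below are assumptions for illustration; note that passing an explicit epoch to step() is deprecated in recent PyTorch releases (it emits a warning but still works).

# Illustrative only: fast-forward the LR schedule to a resumed step.
import torch

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1000)

step_num = 250                 # e.g. restored from a checkpoint
scheduler.step(step_num)       # deprecated-but-working epoch argument
print(scheduler.get_last_lr()) # LR now reflects step 250, not step 0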