mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 10:41:28 +00:00
Allow special args for schedulers
This commit is contained in:
@@ -553,13 +553,17 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
optimizer_params=self.train_config.optimizer_params)
|
||||
self.optimizer = optimizer
|
||||
|
||||
lr_scheduler_params = self.train_config.lr_scheduler_params
|
||||
|
||||
# make sure it had bare minimum
|
||||
if 'max_iterations' not in lr_scheduler_params:
|
||||
lr_scheduler_params['total_iters'] = self.train_config.steps
|
||||
|
||||
lr_scheduler = get_lr_scheduler(
|
||||
self.train_config.lr_scheduler,
|
||||
optimizer,
|
||||
max_iterations=self.train_config.steps,
|
||||
lr_min=self.train_config.lr / 100,
|
||||
**lr_scheduler_params
|
||||
)
|
||||
|
||||
self.lr_scheduler = lr_scheduler
|
||||
|
||||
### HOOK ###
|
||||
@@ -601,6 +605,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
# zero any gradients
|
||||
optimizer.zero_grad()
|
||||
|
||||
self.lr_scheduler.step(self.step_num)
|
||||
|
||||
# self.step_num = 0
|
||||
for step in range(self.step_num, self.train_config.steps):
|
||||
with torch.no_grad():
|
||||
|
||||
Reference in New Issue
Block a user