mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Added doffusers schedulers
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
import torch
|
||||
from typing import Optional
|
||||
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION, get_constant_schedule_with_warmup
|
||||
|
||||
|
||||
def get_lr_scheduler(
|
||||
@@ -11,7 +12,7 @@ def get_lr_scheduler(
|
||||
if 'total_iters' in kwargs:
|
||||
kwargs['T_max'] = kwargs.pop('total_iters')
|
||||
return torch.optim.lr_scheduler.CosineAnnealingLR(
|
||||
optimizer, **kwargs
|
||||
optimizer, **kwargs
|
||||
)
|
||||
elif name == "cosine_with_restarts":
|
||||
if 'total_iters' in kwargs:
|
||||
@@ -34,7 +35,23 @@ def get_lr_scheduler(
|
||||
return torch.optim.lr_scheduler.LinearLR(
|
||||
optimizer, **kwargs
|
||||
)
|
||||
elif name == 'constant_with_warmup':
|
||||
# see if num_warmup_steps is in kwargs
|
||||
if 'num_warmup_steps' not in kwargs:
|
||||
print(f"WARNING: num_warmup_steps not in kwargs. Using default value of 1000")
|
||||
kwargs['num_warmup_steps'] = 1000
|
||||
del kwargs['total_iters']
|
||||
return get_constant_schedule_with_warmup(optimizer, **kwargs)
|
||||
else:
|
||||
# try to use a diffusers scheduler
|
||||
print(f"Trying to use diffusers scheduler {name}")
|
||||
try:
|
||||
name = SchedulerType(name)
|
||||
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
|
||||
return schedule_func(optimizer, **kwargs)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
pass
|
||||
raise ValueError(
|
||||
"Scheduler must be cosine, cosine_with_restarts, step, linear or constant"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user