Added diffusers schedulers

Jaret Burkett
2023-10-29 12:39:50 -06:00
parent 298001439a
commit 48a9bac22d


@@ -1,5 +1,6 @@
 import torch
 from typing import Optional
+from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION, get_constant_schedule_with_warmup


 def get_lr_scheduler(
@@ -11,7 +12,7 @@ def get_lr_scheduler(
         if 'total_iters' in kwargs:
             kwargs['T_max'] = kwargs.pop('total_iters')
         return torch.optim.lr_scheduler.CosineAnnealingLR(
-            optimizer, **kwargs
+            optimizer, **kwargs
         )
     elif name == "cosine_with_restarts":
         if 'total_iters' in kwargs:
@@ -34,7 +35,23 @@ def get_lr_scheduler(
         return torch.optim.lr_scheduler.LinearLR(
             optimizer, **kwargs
         )
+    elif name == 'constant_with_warmup':
+        # default num_warmup_steps if the caller did not provide one
+        if 'num_warmup_steps' not in kwargs:
+            print("WARNING: num_warmup_steps not in kwargs. Using default value of 1000")
+            kwargs['num_warmup_steps'] = 1000
+        kwargs.pop('total_iters', None)  # not a valid argument for this scheduler
+        return get_constant_schedule_with_warmup(optimizer, **kwargs)
     else:
+        # try to resolve the name as a diffusers scheduler
+        print(f"Trying to use diffusers scheduler {name}")
+        try:
+            name = SchedulerType(name)
+            schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
+            return schedule_func(optimizer, **kwargs)
+        except Exception as e:
+            print(e)
+            pass
         raise ValueError(
             "Scheduler must be cosine, cosine_with_restarts, step, linear or constant"
         )
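
For reference, a minimal usage sketch of the patched get_lr_scheduler (not part of the commit). It assumes the function above is in scope and that its signature is (name, optimizer, **kwargs); the step counts are arbitrary. "constant_with_warmup" exercises the new explicit branch, and "polynomial" is an example of a name that only the new diffusers fallback resolves.

import torch

# toy model and optimizer purely for illustration
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# existing branch: handled by torch.optim.lr_scheduler.CosineAnnealingLR
cosine = get_lr_scheduler("cosine", optimizer, total_iters=10_000)

# new branch: diffusers' get_constant_schedule_with_warmup
# (total_iters is accepted but dropped before the call)
warmup = get_lr_scheduler(
    "constant_with_warmup", optimizer,
    num_warmup_steps=500, total_iters=10_000,
)

# new fallback: any other diffusers SchedulerType name, with kwargs
# matching that scheduler function's signature
poly = get_lr_scheduler(
    "polynomial", optimizer,
    num_warmup_steps=500, num_training_steps=10_000,
)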