From 48a9bac22df391dc4e072ec0003e34b2e9f29cce Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Sun, 29 Oct 2023 12:39:50 -0600
Subject: [PATCH] Added diffusers schedulers

---
 toolkit/scheduler.py | 19 ++++++++++++++++++-
 1 file changed, 18 insertions(+), 1 deletion(-)

diff --git a/toolkit/scheduler.py b/toolkit/scheduler.py
index a6b97d8d..95ae36d8 100644
--- a/toolkit/scheduler.py
+++ b/toolkit/scheduler.py
@@ -1,5 +1,6 @@
 import torch
 from typing import Optional
+from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION, get_constant_schedule_with_warmup
 
 
 def get_lr_scheduler(
@@ -11,7 +12,7 @@ def get_lr_scheduler(
         if 'total_iters' in kwargs:
             kwargs['T_max'] = kwargs.pop('total_iters')
         return torch.optim.lr_scheduler.CosineAnnealingLR(
-                optimizer, **kwargs
+            optimizer, **kwargs
         )
     elif name == "cosine_with_restarts":
         if 'total_iters' in kwargs:
@@ -34,7 +35,23 @@
         return torch.optim.lr_scheduler.LinearLR(
             optimizer, **kwargs
         )
+    elif name == 'constant_with_warmup':
+        # fall back to a default warmup length if the caller gave none
+        if 'num_warmup_steps' not in kwargs:
+            print("WARNING: num_warmup_steps not in kwargs. Using default value of 1000")
+            kwargs['num_warmup_steps'] = 1000
+        kwargs.pop('total_iters', None)  # not accepted by get_constant_schedule_with_warmup
+        return get_constant_schedule_with_warmup(optimizer, **kwargs)
     else:
+        # try to resolve the name as a diffusers scheduler before giving up
+        print(f"Trying to use diffusers scheduler {name}")
+        try:
+            name = SchedulerType(name)
+            schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
+            return schedule_func(optimizer, **kwargs)
+        except Exception as e:
+            print(e)
+            pass
         raise ValueError(
             "Scheduler must be cosine, cosine_with_restarts, step, linear or constant"
         )
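
A minimal sketch of exercising the two new code paths, assuming the patch is
applied. The Linear model, AdamW optimizers, and the step counts (100/1000)
are illustrative assumptions, not values taken from the patch:

    import torch
    from toolkit.scheduler import get_lr_scheduler

    model = torch.nn.Linear(4, 4)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    # Explicit branch: total_iters is dropped before the diffusers call,
    # so it is safe to pass it alongside num_warmup_steps.
    scheduler = get_lr_scheduler(
        "constant_with_warmup",
        optimizer,
        num_warmup_steps=100,
        total_iters=1000,
    )

    # Fallback branch: any other diffusers SchedulerType value, e.g.
    # "polynomial", is resolved through TYPE_TO_SCHEDULER_FUNCTION.
    optimizer2 = torch.optim.AdamW(model.parameters(), lr=1e-4)
    poly = get_lr_scheduler(
        "polynomial",
        optimizer2,
        num_warmup_steps=100,
        num_training_steps=1000,
    )

    for _ in range(1000):
        optimizer.step()
        scheduler.step()  # lr ramps up over 100 steps, then stays constant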