ai-toolkit/toolkit/scheduler.py
martintomov 34db804c76 Modal cloud training support, fixed typo in toolkit/scheduler.py, Schnell training support for Colab, issue #92, issue #114 (#115)
* issue #76: load_checkpoint_and_dispatch() 'force_hooks'

https://github.com/ostris/ai-toolkit/issues/76

* RunPod cloud config

https://github.com/ostris/ai-toolkit/issues/90

* change 2x A40 to 1x A40 and update the price per hour

referring to https://github.com/ostris/ai-toolkit/issues/90#issuecomment-2294894929

* include the FLUX.1-schnell setup guide missed in the last commit

* huggingface-cli login required for auth

* #92 peft, #114 Colab, Schnell training in Colab

* modal cloud - run_modal.py and .yaml configs

* run_modal.py mount path example

* modal_examples renamed to modal

* add "Training in Modal" setup guide to README.md

* rename run command in title for consistency
2024-08-22 21:25:44 -06:00


import torch
from typing import Optional
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION, get_constant_schedule_with_warmup


def get_lr_scheduler(
    name: Optional[str],
    optimizer: torch.optim.Optimizer,
    **kwargs,
):
    """Return a learning-rate scheduler for `optimizer` selected by `name`.

    A generic 'total_iters' kwarg is remapped to the parameter the chosen
    scheduler expects (T_max for cosine, T_0 for cosine_with_restarts).
    Unrecognized names fall back to the diffusers scheduler registry.
    """
    if name == "cosine":
        if 'total_iters' in kwargs:
            kwargs['T_max'] = kwargs.pop('total_iters')
        return torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, **kwargs
        )
    elif name == "cosine_with_restarts":
        if 'total_iters' in kwargs:
            kwargs['T_0'] = kwargs.pop('total_iters')
        return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer, **kwargs
        )
    elif name == "step":
        return torch.optim.lr_scheduler.StepLR(
            optimizer, **kwargs
        )
    elif name == "constant":
        # factor defaults to 1/3 in ConstantLR; force 1.0 for a truly flat LR
        if 'factor' not in kwargs:
            kwargs['factor'] = 1.0
        return torch.optim.lr_scheduler.ConstantLR(optimizer, **kwargs)
    elif name == "linear":
        return torch.optim.lr_scheduler.LinearLR(
            optimizer, **kwargs
        )
    elif name == 'constant_with_warmup':
        # default the warmup length if the caller did not provide one
        if 'num_warmup_steps' not in kwargs:
            print("WARNING: num_warmup_steps not in kwargs. Using default value of 1000")
            kwargs['num_warmup_steps'] = 1000
        # the diffusers helper does not accept total_iters; drop it if present
        kwargs.pop('total_iters', None)
        return get_constant_schedule_with_warmup(optimizer, **kwargs)
    else:
        # fall back to the diffusers scheduler registry
        print(f"Trying to use diffusers scheduler {name}")
        try:
            name = SchedulerType(name)
            schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
            return schedule_func(optimizer, **kwargs)
        except Exception as e:
            print(e)
    raise ValueError(
        "Scheduler must be cosine, cosine_with_restarts, step, constant, linear, "
        "constant_with_warmup, or a valid diffusers scheduler name"
    )