mirror of https://github.com/ostris/ai-toolkit.git · synced 2026-01-26 16:39:47 +00:00
58 lines · 1.9 KiB · Python
import torch
from typing import Optional

from diffusers.optimization import (
    SchedulerType,
    TYPE_TO_SCHEDULER_FUNCTION,
    get_constant_schedule_with_warmup,
)


def get_lr_scheduler(
    name: Optional[str],
    optimizer: torch.optim.Optimizer,
    **kwargs,
):
    """Build a learning-rate scheduler by name.

    Handles a few torch.optim.lr_scheduler classes directly and falls back
    to the diffusers scheduler registry for any other name. A generic
    'total_iters' kwarg is remapped to whatever parameter the chosen torch
    scheduler actually expects.
    """
    if name == "cosine":
        # CosineAnnealingLR calls its horizon T_max, not total_iters.
        if 'total_iters' in kwargs:
            kwargs['T_max'] = kwargs.pop('total_iters')
        return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, **kwargs)
    elif name == "cosine_with_restarts":
        # CosineAnnealingWarmRestarts calls its first-cycle length T_0.
        if 'total_iters' in kwargs:
            kwargs['T_0'] = kwargs.pop('total_iters')
        return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, **kwargs)
    elif name == "step":
        return torch.optim.lr_scheduler.StepLR(optimizer, **kwargs)
    elif name == "constant":
        if 'factor' not in kwargs:
            kwargs['factor'] = 1.0
        return torch.optim.lr_scheduler.ConstantLR(optimizer, **kwargs)
    elif name == "linear":
        return torch.optim.lr_scheduler.LinearLR(optimizer, **kwargs)
    elif name == 'constant_with_warmup':
        if 'num_warmup_steps' not in kwargs:
            print("WARNING: num_warmup_steps not in kwargs. Using default value of 1000")
            kwargs['num_warmup_steps'] = 1000
        # get_constant_schedule_with_warmup does not accept total_iters;
        # pop with a default avoids a KeyError when the caller omitted it.
        kwargs.pop('total_iters', None)
        return get_constant_schedule_with_warmup(optimizer, **kwargs)
    else:
        # Fall back to the diffusers scheduler registry.
        print(f"Trying to use diffusers scheduler {name}")
        try:
            name = SchedulerType(name)
            schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
            return schedule_func(optimizer, **kwargs)
        except Exception as e:
            print(e)
        raise ValueError(
            "Scheduler must be cosine, cosine_with_restarts, step, linear, "
            "constant, or constant_with_warmup"
        )
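

# --- Usage sketch (added for illustration; not part of the original file) ---
# A minimal example of how get_lr_scheduler might be driven from a training
# loop. The tiny model, learning rate, and step counts below are placeholder
# assumptions, not values taken from ai-toolkit configs.
if __name__ == "__main__":
    model = torch.nn.Linear(4, 4)  # stand-in for a real network
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    # 'total_iters' is remapped to T_max by the "cosine" branch above.
    scheduler = get_lr_scheduler("cosine", optimizer, total_iters=1000)
    for _ in range(1000):
        optimizer.step()   # forward/backward would happen before this
        scheduler.step()   # advance the schedule once per optimizer step
    print(scheduler.get_last_lr())

    # Any other SchedulerType name should reach the diffusers fallback,
    # e.g. "polynomial" (assuming it takes warmup and total step counts):
    poly = get_lr_scheduler(
        "polynomial", optimizer, num_warmup_steps=100, num_training_steps=1000
    )
    print(type(poly).__name__)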