Mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-04-29 18:51:37 +00:00)
Added stochastic rounding to adafactor. ILora adjustments
@@ -83,6 +83,8 @@ class SDTrainer(BaseSDTrainProcess):
                 self.taesd.requires_grad_(False)
 
     def hook_before_train_loop(self):
+        if self.train_config.do_prior_divergence:
+            self.do_prior_prediction = True
         # move vae to device if we did not cache latents
         if not self.is_latents_cached:
             self.sd.vae.eval()
@@ -290,7 +292,7 @@ class SDTrainer(BaseSDTrainProcess):
                 # target = (noise * mask_multiplier) + (prior_pred * prior_mask_multiplier)
                 # set masked multiplier to 1.0 so we dont double apply it
                 # mask_multiplier = 1.0
-            elif prior_pred is not None:
+            elif prior_pred is not None and not self.train_config.do_prior_divergence:
                 assert not self.train_config.train_turbo
                 # matching adapter prediction
                 target = prior_pred
@@ -347,6 +349,9 @@ class SDTrainer(BaseSDTrainProcess):
             else:
                 loss = torch.nn.functional.mse_loss(pred.float(), target.float(), reduction="none")
 
+            if self.train_config.do_prior_divergence and prior_pred is not None:
+                loss = loss + (torch.nn.functional.mse_loss(pred.float(), prior_pred.float(), reduction="none") * -1.0)
+
             # multiply by our mask
             loss = loss * mask_multiplier
 
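Together with the earlier hunk (the prior prediction is no longer used as the target when divergence mode is on), the prior now only enters the loss with a negative sign, pushing the network away from the prior's output. A minimal standalone sketch of that loss term, with made-up tensor shapes rather than the trainer's real batch:

import torch
import torch.nn.functional as F

# hypothetical shapes; in the trainer these come from the noise-prediction step
pred = torch.randn(2, 4, 64, 64)        # network output
target = torch.randn(2, 4, 64, 64)      # usual training target (e.g. noise)
prior_pred = torch.randn(2, 4, 64, 64)  # frozen prior's prediction

loss = F.mse_loss(pred.float(), target.float(), reduction="none")
# negative prior term: minimizing it increases the distance to the prior prediction
loss = loss + (F.mse_loss(pred.float(), prior_pred.float(), reduction="none") * -1.0)
loss = loss.mean()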
@@ -688,6 +688,16 @@ class BaseSDTrainProcess(BaseTrainProcess):
                     trigger=self.trigger_word,
                     add_if_not_present=not is_reg,
                 )
 
+                if not is_reg and self.train_config.prompt_saturation_chance > 0.0:
+                    # do random prompt saturation by expanding the prompt to hit at least 77 tokens
+                    if random.random() < self.train_config.prompt_saturation_chance:
+                        est_num_tokens = len(prompt.split(' '))
+                        if est_num_tokens < 77:
+                            num_repeats = int(77 / est_num_tokens) + 1
+                            prompt = ', '.join([prompt] * num_repeats)
+
                 conditioned_prompts.append(prompt)
 
         with self.timer('prepare_latents'):
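The saturation branch above estimates token count with a simple whitespace split and then tiles the prompt until it should cover CLIP's 77-token window. A standalone sketch of the same expansion, with a made-up caption and the chance forced to 1.0 so the branch always fires:

import random

prompt = "a photo of a dog"        # hypothetical caption
prompt_saturation_chance = 1.0     # forced for the demo; normally a small value like 0.1

if random.random() < prompt_saturation_chance:
    est_num_tokens = len(prompt.split(' '))   # rough proxy for tokenizer length
    if est_num_tokens < 77:
        num_repeats = int(77 / est_num_tokens) + 1
        prompt = ', '.join([prompt] * num_repeats)

print(prompt)  # "a photo of a dog, a photo of a dog, ..." repeated 16 times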
Submodule repositories/ipadapter updated: f71c943b7e...5a18b1f366
@@ -308,6 +308,13 @@ class TrainConfig:
         # scale the prediction by this. Increase for more detail, decrease for less
         self.pred_scaler = kwargs.get('pred_scaler', 1.0)
 
+        # repeats the prompt a few times to saturate the encoder
+        self.prompt_saturation_chance = kwargs.get('prompt_saturation_chance', 0.0)
+
+        # applies negative loss on the prior to encourage network to diverge from it
+        self.do_prior_divergence = kwargs.get('do_prior_divergence', False)
+
+
 
 class ModelConfig:
     def __init__(self, **kwargs):
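Both new options default to off, so existing configs are unaffected. A hedged sketch of how they would be switched on through TrainConfig kwargs (only the field names come from the diff; the import path and surrounding config plumbing are assumptions):

from toolkit.config_modules import TrainConfig  # module path assumed, not shown in the diff

train_config = TrainConfig(
    pred_scaler=1.0,                 # existing option shown in the hunk above
    prompt_saturation_chance=0.1,    # roughly 10% of non-regularization prompts get expanded
    do_prior_divergence=True,        # adds the negative prior-matching loss term
)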
@@ -49,7 +49,7 @@ class InstantLoRAMidModule(torch.nn.Module):
             print(scaler.shape)
             raise e
         # apply tanh to limit values to -1 to 1
-        scaler = torch.tanh(scaler)
+        # scaler = torch.tanh(scaler)
         return x * scaler
 
 
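The only change to InstantLoRAMidModule is dropping the tanh clamp, so the learned scaler now multiplies the activations unbounded instead of being squashed into (-1, 1). A tiny standalone comparison with made-up values:

import torch

scaler = torch.tensor([-3.0, -0.5, 0.0, 0.5, 3.0])  # hypothetical scaler values
x = torch.ones(5)

print(x * torch.tanh(scaler))  # old behavior: roughly [-0.995, -0.462, 0.0, 0.462, 0.995]
print(x * scaler)              # new behavior: [-3.0, -0.5, 0.0, 0.5, 3.0]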
@@ -65,7 +65,16 @@ def get_optimizer(
     elif lower_type == 'adagrad':
         optimizer = torch.optim.Adagrad(params, lr=float(learning_rate), **optimizer_params)
     elif lower_type == 'adafactor':
+        # hack in stochastic rounding
+        if 'relative_step' not in optimizer_params:
+            optimizer_params['relative_step'] = False
+        if 'scale_parameter' not in optimizer_params:
+            optimizer_params['scale_parameter'] = True
+        if 'warmup_init' not in optimizer_params:
+            optimizer_params['warmup_init'] = False
         optimizer = Adafactor(params, lr=float(learning_rate), **optimizer_params)
+        from toolkit.util.adafactor_stochastic_rounding import step_adafactor
+        optimizer.step = step_adafactor.__get__(optimizer, Adafactor)
     else:
         raise ValueError(f'Unknown optimizer type {optimizer_type}')
     return optimizer
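The step override relies on the descriptor protocol: function.__get__(instance, cls) returns a bound method, so the replacement step receives the optimizer as self without subclassing Adafactor. The same trick in isolation, using a toy class rather than anything toolkit-specific:

class Counter:
    def step(self):
        return "original step"

def patched_step(self):
    return f"patched step on {type(self).__name__}"

opt = Counter()
# bind the plain function to this one instance; other Counter objects are untouched
opt.step = patched_step.__get__(opt, Counter)
print(opt.step())  # "patched step on Counter"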
toolkit/util/adafactor_stochastic_rounding.py (new file, 119 lines)
@@ -0,0 +1,119 @@
+# ref https://github.com/Nerogar/OneTrainer/compare/master...stochastic_rounding
+import math
+import torch
+from torch import Tensor
+
+
+def copy_stochastic_(target: Tensor, source: Tensor):
+    # create a random 16 bit integer
+    result = torch.randint_like(
+        source,
+        dtype=torch.int32,
+        low=0,
+        high=(1 << 16),
+    )
+
+    # add the random number to the lower 16 bit of the mantissa
+    result.add_(source.view(dtype=torch.int32))
+
+    # mask off the lower 16 bit of the mantissa
+    result.bitwise_and_(-65536)  # -65536 = FFFF0000 as a signed int32
+
+    # copy the higher 16 bit into the target tensor
+    target.copy_(result.view(dtype=torch.float32))
+
+
+@torch.no_grad()
+def step_adafactor(self, closure=None):
+    """
+    Performs a single optimization step
+
+    Arguments:
+        closure (callable, optional): A closure that reevaluates the model
+            and returns the loss.
+    """
+    loss = None
+    if closure is not None:
+        loss = closure()
+
+    for group in self.param_groups:
+        for p in group["params"]:
+            if p.grad is None:
+                continue
+            grad = p.grad
+            if grad.dtype in {torch.float16, torch.bfloat16}:
+                grad = grad.float()
+            if grad.is_sparse:
+                raise RuntimeError("Adafactor does not support sparse gradients.")
+
+            state = self.state[p]
+            grad_shape = grad.shape
+
+            factored, use_first_moment = self._get_options(group, grad_shape)
+            # State Initialization
+            if len(state) == 0:
+                state["step"] = 0
+
+                if use_first_moment:
+                    # Exponential moving average of gradient values
+                    state["exp_avg"] = torch.zeros_like(grad)
+                if factored:
+                    state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad)
+                    state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad)
+                else:
+                    state["exp_avg_sq"] = torch.zeros_like(grad)
+
+                state["RMS"] = 0
+            else:
+                if use_first_moment:
+                    state["exp_avg"] = state["exp_avg"].to(grad)
+                if factored:
+                    state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad)
+                    state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad)
+                else:
+                    state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)
+
+            p_data_fp32 = p
+            if p.dtype in {torch.float16, torch.bfloat16}:
+                p_data_fp32 = p_data_fp32.float()
+
+            state["step"] += 1
+            state["RMS"] = self._rms(p_data_fp32)
+            lr = self._get_lr(group, state)
+
+            beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
+            update = (grad ** 2) + group["eps"][0]
+            if factored:
+                exp_avg_sq_row = state["exp_avg_sq_row"]
+                exp_avg_sq_col = state["exp_avg_sq_col"]
+
+                exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t))
+                exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t))
+
+                # Approximation of exponential moving average of square of gradient
+                update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
+                update.mul_(grad)
+            else:
+                exp_avg_sq = state["exp_avg_sq"]
+
+                exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t))
+                update = exp_avg_sq.rsqrt().mul_(grad)
+
+            update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0))
+            update.mul_(lr)
+
+            if use_first_moment:
+                exp_avg = state["exp_avg"]
+                exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"]))
+                update = exp_avg
+
+            if group["weight_decay"] != 0:
+                p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr))
+
+            p_data_fp32.add_(-update)
+
+            if p.dtype == torch.bfloat16:
+                copy_stochastic_(p, p_data_fp32)
+            elif p.dtype == torch.float16:
+                p.copy_(p_data_fp32)
+
+    return loss
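copy_stochastic_ adds 16 random bits below the bf16 cut-off before truncating, so an fp32 value that falls between two bf16 neighbours rounds to either one with probability proportional to its distance, which makes the rounding unbiased in expectation. A quick sanity-check sketch using the module added above (values chosen for the demo):

import torch
from toolkit.util.adafactor_stochastic_rounding import copy_stochastic_

# 1.002 sits between the bf16 neighbours 1.0 and ~1.0039 (spacing near 1.0 is 2**-8)
source = torch.full((100_000,), 1.002, dtype=torch.float32)
target = torch.empty_like(source, dtype=torch.bfloat16)

copy_stochastic_(target, source)
print(target.float().unique())       # only the two adjacent bf16 values appear
print(target.float().mean().item())  # close to 1.002 on average, unlike deterministic truncation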