diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index 7af2f0b..9c397b4 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -83,6 +83,8 @@ class SDTrainer(BaseSDTrainProcess): self.taesd.requires_grad_(False) def hook_before_train_loop(self): + if self.train_config.do_prior_divergence: + self.do_prior_prediction = True # move vae to device if we did not cache latents if not self.is_latents_cached: self.sd.vae.eval() @@ -290,7 +292,7 @@ class SDTrainer(BaseSDTrainProcess): # target = (noise * mask_multiplier) + (prior_pred * prior_mask_multiplier) # set masked multiplier to 1.0 so we dont double apply it # mask_multiplier = 1.0 - elif prior_pred is not None: + elif prior_pred is not None and not self.train_config.do_prior_divergence: assert not self.train_config.train_turbo # matching adapter prediction target = prior_pred @@ -347,6 +349,9 @@ class SDTrainer(BaseSDTrainProcess): else: loss = torch.nn.functional.mse_loss(pred.float(), target.float(), reduction="none") + if self.train_config.do_prior_divergence and prior_pred is not None: + loss = loss + (torch.nn.functional.mse_loss(pred.float(), prior_pred.float(), reduction="none") * -1.0) + # multiply by our mask loss = loss * mask_multiplier diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 3f87114..472fada 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -688,6 +688,16 @@ class BaseSDTrainProcess(BaseTrainProcess): trigger=self.trigger_word, add_if_not_present=not is_reg, ) + + if not is_reg and self.train_config.prompt_saturation_chance > 0.0: + # do random prompt saturation by expanding the prompt to hit at least 77 tokens + if random.random() < self.train_config.prompt_saturation_chance: + est_num_tokens = len(prompt.split(' ')) + if est_num_tokens < 77: + num_repeats = int(77 / est_num_tokens) + 1 + prompt = ', '.join([prompt] * num_repeats) + + conditioned_prompts.append(prompt) with self.timer('prepare_latents'): diff --git a/repositories/ipadapter b/repositories/ipadapter index f71c943..5a18b1f 160000 --- a/repositories/ipadapter +++ b/repositories/ipadapter @@ -1 +1 @@ -Subproject commit f71c943b7e1d3ffccae8e4f04b9adebac037e19f +Subproject commit 5a18b1f3660acaf8bee8250692d6fb3548a19b14 diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index ebcdb96..1bf89dc 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -308,6 +308,13 @@ class TrainConfig: # scale the prediction by this. Increase for more detail, decrease for less self.pred_scaler = kwargs.get('pred_scaler', 1.0) + # repeats the prompt a few times to saturate the encoder + self.prompt_saturation_chance = kwargs.get('prompt_saturation_chance', 0.0) + + # applies negative loss on the prior to encourage network to diverge from it + self.do_prior_divergence = kwargs.get('do_prior_divergence', False) + + class ModelConfig: def __init__(self, **kwargs): diff --git a/toolkit/models/ilora.py b/toolkit/models/ilora.py index a0c056c..2698d15 100644 --- a/toolkit/models/ilora.py +++ b/toolkit/models/ilora.py @@ -49,7 +49,7 @@ class InstantLoRAMidModule(torch.nn.Module): print(scaler.shape) raise e # apply tanh to limit values to -1 to 1 - scaler = torch.tanh(scaler) + # scaler = torch.tanh(scaler) return x * scaler diff --git a/toolkit/optimizer.py b/toolkit/optimizer.py index 2332b16..5d2ddde 100644 --- a/toolkit/optimizer.py +++ b/toolkit/optimizer.py @@ -65,7 +65,16 @@ def get_optimizer( elif lower_type == 'adagrad': optimizer = torch.optim.Adagrad(params, lr=float(learning_rate), **optimizer_params) elif lower_type == 'adafactor': + # hack in stochastic rounding + if 'relative_step' not in optimizer_params: + optimizer_params['relative_step'] = False + if 'scale_parameter' not in optimizer_params: + optimizer_params['scale_parameter'] = True + if 'warmup_init' not in optimizer_params: + optimizer_params['warmup_init'] = False optimizer = Adafactor(params, lr=float(learning_rate), **optimizer_params) + from toolkit.util.adafactor_stochastic_rounding import step_adafactor + optimizer.step = step_adafactor.__get__(optimizer, Adafactor) else: raise ValueError(f'Unknown optimizer type {optimizer_type}') return optimizer diff --git a/toolkit/util/adafactor_stochastic_rounding.py b/toolkit/util/adafactor_stochastic_rounding.py new file mode 100644 index 0000000..0e65a32 --- /dev/null +++ b/toolkit/util/adafactor_stochastic_rounding.py @@ -0,0 +1,119 @@ +# ref https://github.com/Nerogar/OneTrainer/compare/master...stochastic_rounding +import math +import torch +from torch import Tensor + + +def copy_stochastic_(target: Tensor, source: Tensor): + # create a random 16 bit integer + result = torch.randint_like( + source, + dtype=torch.int32, + low=0, + high=(1 << 16), + ) + + # add the random number to the lower 16 bit of the mantissa + result.add_(source.view(dtype=torch.int32)) + + # mask off the lower 16 bit of the mantissa + result.bitwise_and_(-65536) # -65536 = FFFF0000 as a signed int32 + + # copy the higher 16 bit into the target tensor + target.copy_(result.view(dtype=torch.float32)) + + +@torch.no_grad() +def step_adafactor(self, closure=None): + """ + Performs a single optimization step + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad + if grad.dtype in {torch.float16, torch.bfloat16}: + grad = grad.float() + if grad.is_sparse: + raise RuntimeError("Adafactor does not support sparse gradients.") + + state = self.state[p] + grad_shape = grad.shape + + factored, use_first_moment = self._get_options(group, grad_shape) + # State Initialization + if len(state) == 0: + state["step"] = 0 + + if use_first_moment: + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like(grad) + if factored: + state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad) + state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad) + else: + state["exp_avg_sq"] = torch.zeros_like(grad) + + state["RMS"] = 0 + else: + if use_first_moment: + state["exp_avg"] = state["exp_avg"].to(grad) + if factored: + state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad) + state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad) + else: + state["exp_avg_sq"] = state["exp_avg_sq"].to(grad) + + p_data_fp32 = p + if p.dtype in {torch.float16, torch.bfloat16}: + p_data_fp32 = p_data_fp32.float() + + state["step"] += 1 + state["RMS"] = self._rms(p_data_fp32) + lr = self._get_lr(group, state) + + beta2t = 1.0 - math.pow(state["step"], group["decay_rate"]) + update = (grad ** 2) + group["eps"][0] + if factored: + exp_avg_sq_row = state["exp_avg_sq_row"] + exp_avg_sq_col = state["exp_avg_sq_col"] + + exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t)) + exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t)) + + # Approximation of exponential moving average of square of gradient + update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) + update.mul_(grad) + else: + exp_avg_sq = state["exp_avg_sq"] + + exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t)) + update = exp_avg_sq.rsqrt().mul_(grad) + + update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0)) + update.mul_(lr) + + if use_first_moment: + exp_avg = state["exp_avg"] + exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"])) + update = exp_avg + + if group["weight_decay"] != 0: + p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr)) + + p_data_fp32.add_(-update) + + if p.dtype == torch.bfloat16: + copy_stochastic_(p, p_data_fp32) + elif p.dtype == torch.float16: + p.copy_(p_data_fp32) + + return loss