ostris/ai-toolkit (https://github.com/ostris/ai-toolkit.git)

Commit: improved correction of pred norm by targeting the prior
@@ -227,37 +227,30 @@ class SDTrainer(BaseSDTrainProcess):
         if self.train_config.correct_pred_norm and not is_reg:
             with torch.no_grad():
-                # adjust the noise target in the opposite direction of the noise pred mean and std offset
-                # this will apply additional force to the model to correct itself to match the norm of the noise
-                noise_pred_mean, noise_pred_std = get_mean_std(noise_pred)
-                noise_mean, noise_std = get_mean_std(noise)
-
-                # apply the inverse offset of the mean and std to the noise
-                noise_additional_mean = noise_mean - noise_pred_mean
-                noise_additional_std = noise_std - noise_pred_std
-
-                # adjust for multiplier
-                noise_additional_mean = noise_additional_mean * self.train_config.correct_pred_norm_multiplier
-                noise_additional_std = noise_additional_std * self.train_config.correct_pred_norm_multiplier
-
-                noise_target_std = noise_std + noise_additional_std
-                noise_target_mean = noise_mean + noise_additional_mean
-
-                noise_pred_target_std = noise_pred_std - noise_additional_std
-                noise_pred_target_mean = noise_pred_mean - noise_additional_mean
-                noise_pred_target_std = noise_pred_target_std.detach()
-                noise_pred_target_mean = noise_pred_target_mean.detach()
-
-                # match the noise to the target
-                noise = (noise - noise_mean) / noise_std
-                noise = noise * noise_target_std + noise_target_mean
-                noise = noise.detach()
-
-                # match the noise pred to the target
-                # noise_pred = (noise_pred - noise_pred_mean) / noise_pred_std
-                # noise_pred = noise_pred * noise_pred_target_std + noise_pred_target_mean
+                # this only works if doing a prior pred
+                if prior_pred is not None:
+                    prior_mean = prior_pred.mean([2,3], keepdim=True)
+                    prior_std = prior_pred.std([2,3], keepdim=True)
+                    noise_mean = noise_pred.mean([2,3], keepdim=True)
+                    noise_std = noise_pred.std([2,3], keepdim=True)
+
+                    mean_adjust = prior_mean - noise_mean
+                    std_adjust = prior_std - noise_std
+
+                    mean_adjust = mean_adjust * self.train_config.correct_pred_norm_multiplier
+                    std_adjust = std_adjust * self.train_config.correct_pred_norm_multiplier
+
+                    target_mean = noise_mean + mean_adjust
+                    target_std = noise_std + std_adjust
+
+                    eps = 1e-5
+                    # adjust the noise target to match the current knowledge of the model
+                    # noise_mean, noise_std = get_mean_std(noise)
+                    # match the noise to the prior
+                    noise = (noise - noise_mean) / (noise_std + eps)
+                    noise = noise * (target_std + eps) + target_mean
+                    noise = noise.detach()

         if self.train_config.inverted_mask_prior and prior_pred is not None and has_mask:
             assert not self.train_config.train_turbo
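For context, the new branch can be read as a standalone transform: it measures the per-channel mean and std of the model's prediction and of the prior's prediction over the spatial dims, then shifts the noise target so the model is pushed back toward the prior's statistics. Below is a minimal sketch of that arithmetic; the function name `shift_target_toward_prior` and the variable names are hypothetical, and it assumes (batch, channels, height, width) latents as the `.mean([2,3], ...)` calls in the hunk imply.

```python
import torch

def shift_target_toward_prior(noise, noise_pred, prior_pred, multiplier=1.0, eps=1e-5):
    # hypothetical standalone version of the hunk above, not code from the repo
    # per-channel statistics over the spatial dims; keepdim keeps shapes broadcastable
    prior_mean = prior_pred.mean([2, 3], keepdim=True)
    prior_std = prior_pred.std([2, 3], keepdim=True)
    pred_mean = noise_pred.mean([2, 3], keepdim=True)
    pred_std = noise_pred.std([2, 3], keepdim=True)

    # how far the trained model's statistics have drifted from the prior's, scaled
    target_mean = pred_mean + (prior_mean - pred_mean) * multiplier
    target_std = pred_std + (prior_std - pred_std) * multiplier

    # standardize the target with the prediction's stats, then re-scale toward the prior
    noise = (noise - pred_mean) / (pred_std + eps)
    noise = noise * (target_std + eps) + target_mean
    return noise.detach()

# example: latents shaped (batch, channels, height, width)
noise = torch.randn(2, 4, 64, 64)
noise_pred = torch.randn(2, 4, 64, 64)
prior_pred = torch.randn(2, 4, 64, 64)
new_target = shift_target_toward_prior(noise, noise_pred, prior_pred, multiplier=0.5)
```

With `multiplier=1.0` the target takes on exactly the prior's mean and std; smaller values apply only a partial correction.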
@@ -306,6 +306,7 @@ class ModelConfig:
         self.refiner_name_or_path = kwargs.get('refiner_name_or_path', None)
         self._original_refiner_name_or_path = self.refiner_name_or_path
         self.refiner_start_at = kwargs.get('refiner_start_at', 0.5)
+        self.lora_path = kwargs.get('lora_path', None)

         # only for SDXL models for now
         self.use_text_encoder_1: bool = kwargs.get('use_text_encoder_1', True)
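Since ModelConfig pulls every option from kwargs, the new key can be set like any other model option. A hypothetical usage sketch; the base model id and file path are placeholders, and `name_or_path` is assumed from the project's config conventions rather than taken from this commit:

```python
# hypothetical construction; in practice these kwargs come from the
# `model:` section of an ai-toolkit YAML config
model_config = ModelConfig(
    name_or_path="stabilityai/stable-diffusion-xl-base-1.0",  # placeholder base model
    lora_path="/path/to/existing_lora.safetensors",  # new key, defaults to None
)
```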
@@ -281,6 +281,12 @@ class StableDiffusion:
             self.unet.requires_grad_(False)
             self.unet.eval()

+            # load any loras we have
+            if self.model_config.lora_path is not None:
+                pipe.load_lora_weights(self.model_config.lora_path, adapter_name="lora1")
+                pipe.fuse_lora()
+                self.unet.fuse_lora()
+
         self.tokenizer = tokenizer
         self.text_encoder = text_encoder
         self.pipeline = pipe
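The loading pattern in this last hunk uses the public diffusers LoRA API, so it can be reproduced outside the trainer. A minimal sketch, assuming a standard pipeline; the model id and LoRA path are placeholders:

```python
from diffusers import StableDiffusionPipeline

# placeholder model id and LoRA path; the repo resolves these from its config
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.load_lora_weights("/path/to/lora.safetensors", adapter_name="lora1")
# fusing bakes the LoRA deltas into the base weights, so subsequent training
# or sampling sees a plain UNet with no per-step adapter overhead
pipe.fuse_lora()
```

Fusing up front means the loaded LoRA becomes part of the starting weights that any subsequent fine-tuning builds on, rather than a separate adapter kept alongside them.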