improved correction of pred norm by targeting the prior

Jaret Burkett
2024-02-01 06:31:04 -07:00
parent 1ae1017748
commit 177c7130ec
3 changed files with 26 additions and 26 deletions


@@ -227,37 +227,30 @@ class SDTrainer(BaseSDTrainProcess):
         if self.train_config.correct_pred_norm and not is_reg:
             with torch.no_grad():
+                # this only works if doing a prior pred
+                if prior_pred is not None:
+                    prior_mean = prior_pred.mean([2,3], keepdim=True)
+                    prior_std = prior_pred.std([2,3], keepdim=True)
+                    noise_mean = noise_pred.mean([2,3], keepdim=True)
+                    noise_std = noise_pred.std([2,3], keepdim=True)
-                # adjust the noise target in the opposite direction of the noise pred mean and std offset
-                # this will apply additional force on the model to correct itself to match the norm of the noise
-                noise_pred_mean, noise_pred_std = get_mean_std(noise_pred)
-                noise_mean, noise_std = get_mean_std(noise)
+                    mean_adjust = prior_mean - noise_mean
+                    std_adjust = prior_std - noise_std
-                # apply the inverse offset of the mean and std to the noise
-                noise_additional_mean = noise_mean - noise_pred_mean
-                noise_additional_std = noise_std - noise_pred_std
+                    mean_adjust = mean_adjust * self.train_config.correct_pred_norm_multiplier
+                    std_adjust = std_adjust * self.train_config.correct_pred_norm_multiplier
-                # adjust for multiplier
-                noise_additional_mean = noise_additional_mean * self.train_config.correct_pred_norm_multiplier
-                noise_additional_std = noise_additional_std * self.train_config.correct_pred_norm_multiplier
+                    target_mean = noise_mean + mean_adjust
+                    target_std = noise_std + std_adjust
-                noise_target_std = noise_std + noise_additional_std
-                noise_target_mean = noise_mean + noise_additional_mean
+                    eps = 1e-5
-                noise_pred_target_std = noise_pred_std - noise_additional_std
-                noise_pred_target_mean = noise_pred_mean - noise_additional_mean
-                noise_pred_target_std = noise_pred_target_std.detach()
-                noise_pred_target_mean = noise_pred_target_mean.detach()
-                # match the noise to the target
-                noise = (noise - noise_mean) / noise_std
-                noise = noise * noise_target_std + noise_target_mean
-                noise = noise.detach()
-                # match the noise pred to the target
-                # noise_pred = (noise_pred - noise_pred_mean) / noise_pred_std
-                # noise_pred = noise_pred * noise_pred_target_std + noise_pred_target_mean
+                    # adjust the noise target to match the current knowledge of the model
+                    # noise_mean, noise_std = get_mean_std(noise)
+                    # match the noise to the prior
+                    noise = (noise - noise_mean) / (noise_std + eps)
+                    noise = noise * (target_std + eps) + target_mean
+                    noise = noise.detach()

         if self.train_config.inverted_mask_prior and prior_pred is not None and has_mask:
             assert not self.train_config.train_turbo
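
In effect, the rewritten block drops the old scheme of pushing the target against the noise prediction's own mean/std drift. It now measures per-sample, per-channel statistics over the spatial dims for both the prior prediction and the current noise prediction, then renormalizes the noise target toward the prior's statistics, scaled by correct_pred_norm_multiplier. A minimal standalone sketch of that correction, mirroring the hunk's logic (the function name, signature, and default multiplier are assumptions, not the trainer's exact wiring):

import torch

def correct_noise_target(noise, noise_pred, prior_pred, multiplier=1.0, eps=1e-5):
    # all tensors are (batch, channels, height, width); stats are taken per sample
    # and per channel over the spatial dims, matching the mean([2,3]) calls above
    with torch.no_grad():
        prior_mean = prior_pred.mean([2, 3], keepdim=True)
        prior_std = prior_pred.std([2, 3], keepdim=True)
        pred_mean = noise_pred.mean([2, 3], keepdim=True)
        pred_std = noise_pred.std([2, 3], keepdim=True)

        # offset of the prior's stats from the current prediction's stats, scaled
        mean_adjust = (prior_mean - pred_mean) * multiplier
        std_adjust = (prior_std - pred_std) * multiplier

        target_mean = pred_mean + mean_adjust
        target_std = pred_std + std_adjust

        # normalize the noise target with the prediction's stats (as the hunk does),
        # then rescale it to the prior-adjusted statistics
        noise = (noise - pred_mean) / (pred_std + eps)
        noise = noise * (target_std + eps) + target_mean
    return noise.detach()

Because everything runs under no_grad and the result is detached, only the target moves; gradients still flow solely through the model's prediction.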


@@ -306,6 +306,7 @@ class ModelConfig:
         self.refiner_name_or_path = kwargs.get('refiner_name_or_path', None)
         self._original_refiner_name_or_path = self.refiner_name_or_path
         self.refiner_start_at = kwargs.get('refiner_start_at', 0.5)
+        self.lora_path = kwargs.get('lora_path', None)
         # only for SDXL models for now
         self.use_text_encoder_1: bool = kwargs.get('use_text_encoder_1', True)
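
The new lora_path option is read from the model kwargs exactly like its neighbours and defaults to None, so existing configs are unaffected. A small illustrative sketch, assuming ModelConfig is importable from toolkit.config_modules and that name_or_path is the usual base-model field; both paths are made up:

from toolkit.config_modules import ModelConfig

model_config = ModelConfig(
    name_or_path="stabilityai/stable-diffusion-xl-base-1.0",  # example base model
    lora_path="/path/to/pretrained_lora.safetensors",         # hypothetical LoRA file
)
print(model_config.lora_path)                             # -> /path/to/pretrained_lora.safetensors
print(ModelConfig(name_or_path="some/model").lora_path)   # -> None (default)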


@@ -281,6 +281,12 @@ class StableDiffusion:
         self.unet.requires_grad_(False)
         self.unet.eval()
+        # load any loras we have
+        if self.model_config.lora_path is not None:
+            pipe.load_lora_weights(self.model_config.lora_path, adapter_name="lora1")
+            pipe.fuse_lora()
+            self.unet.fuse_lora()

         self.tokenizer = tokenizer
         self.text_encoder = text_encoder
         self.pipeline = pipe
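
load_lora_weights followed by fuse_lora is the standard diffusers pattern for baking a LoRA into the base weights so nothing extra has to be tracked afterwards. A standalone sketch against a plain diffusers pipeline (the model id and LoRA path are placeholders, not taken from this repo; the adapter_name argument needs a reasonably recent diffusers release):

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # placeholder base model
    torch_dtype=torch.float16,
).to("cuda")

# register the LoRA under a named adapter, then fold it into the base weights
pipe.load_lora_weights("/path/to/lora.safetensors", adapter_name="lora1")
pipe.fuse_lora()

# once fused, the UNet carries the LoRA permanently and no adapter bookkeeping
# is needed at sampling time
image = pipe("a photo of an astronaut riding a horse", num_inference_steps=25).images[0]
image.save("fused_lora_sample.png")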