From c35b78f0d4212c8c1daeb693cb61d8c65de4312f Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Sat, 29 Jul 2023 17:14:14 -0600 Subject: [PATCH] Added random weight adjuster to prevent overfitting --- jobs/TrainJob.py | 4 ++-- jobs/process/TrainSliderProcess.py | 17 ++++++++++++----- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/jobs/TrainJob.py b/jobs/TrainJob.py index 3fc9b7db..a1939890 100644 --- a/jobs/TrainJob.py +++ b/jobs/TrainJob.py @@ -16,8 +16,8 @@ sys.path.append(REPOS_ROOT) process_dict = { 'vae': 'TrainVAEProcess', - 'slider_dev': 'TrainSliderProcess', - 'slider': 'TrainSliderProcessOld', + 'slider': 'TrainSliderProcess', + 'slider_old': 'TrainSliderProcessOld', 'lora_hack': 'TrainLoRAHack', 'rescale_sd': 'TrainSDRescaleProcess', } diff --git a/jobs/process/TrainSliderProcess.py b/jobs/process/TrainSliderProcess.py index 90535ed8..1a9c8614 100644 --- a/jobs/process/TrainSliderProcess.py +++ b/jobs/process/TrainSliderProcess.py @@ -224,6 +224,7 @@ class TrainSliderProcess(BaseSDTrainProcess): neutral=cache[neutral], action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE, multiplier=target.multiplier, + both_targets=cache[f"{target.positive} {target.negative}"], empty_prompt=cache[""], weight=target.weight ), @@ -240,6 +241,7 @@ class TrainSliderProcess(BaseSDTrainProcess): neutral=cache[neutral], action=ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE, multiplier=target.multiplier, + both_targets=cache[f"{target.positive} {target.negative}"], empty_prompt=cache[""], weight=target.weight ), @@ -255,6 +257,7 @@ class TrainSliderProcess(BaseSDTrainProcess): negative_target_with_neutral=cache[f"{target.positive} {neutral}"], neutral=cache[neutral], action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE, + both_targets=cache[f"{target.positive} {target.negative}"], empty_prompt=cache[""], multiplier=target.multiplier * -1.0, weight=target.weight @@ -269,6 +272,7 @@ class TrainSliderProcess(BaseSDTrainProcess): positive_target_with_neutral=cache[f"{target.positive} {neutral}"], negative_target=cache[f"{target.negative}"], negative_target_with_neutral=cache[f"{target.negative} {neutral}"], + both_targets=cache[f"{target.positive} {target.negative}"], neutral=cache[neutral], action=ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE, empty_prompt=cache[""], @@ -335,6 +339,9 @@ class TrainSliderProcess(BaseSDTrainProcess): def hook_train_loop(self): dtype = get_torch_dtype(self.train_config.dtype) + # get random multiplier between 1 and 3 + rand_weight = torch.rand((1,)).item() * 2 + 1 + # get a random pair prompt_pair: EncodedPromptPair = self.prompt_pairs[ torch.randint(0, len(self.prompt_pairs), (1,)).item() @@ -373,7 +380,7 @@ class TrainSliderProcess(BaseSDTrainProcess): ) # set network multiplier - self.network.multiplier = multiplier + self.network.multiplier = multiplier * rand_weight with torch.no_grad(): self.sd.noise_scheduler.set_timesteps( @@ -399,7 +406,7 @@ class TrainSliderProcess(BaseSDTrainProcess): with self.network: assert self.network.is_active - self.network.multiplier = multiplier + self.network.multiplier = multiplier * rand_weight denoised_latents = self.diffuse_some_steps( latents, # pass simple noise latents train_tools.concat_prompt_embeddings( @@ -443,16 +450,16 @@ class TrainSliderProcess(BaseSDTrainProcess): with self.network: # anchor whatever weight prompt pair is using pos_nem_mult = 1.0 if prompt_pair.multiplier > 0 else -1.0 - self.network.multiplier = anchor.multiplier * pos_nem_mult + self.network.multiplier = anchor.multiplier * pos_nem_mult * rand_weight anchor_pred_noise = get_noise_pred( anchor.prompt, anchor.neg_prompt, 1, current_timestep, denoised_latents ).to("cpu", dtype=torch.float32) - self.network.multiplier = prompt_pair.multiplier + self.network.multiplier = prompt_pair.multiplier * rand_weight with self.network: - self.network.multiplier = prompt_pair.multiplier + self.network.multiplier = prompt_pair.multiplier * rand_weight target_latents = get_noise_pred( positive, target_class, 1, current_timestep, denoised_latents ).to("cpu", dtype=torch.float32)