Added random weight adjuster to prevent overfitting

This commit is contained in:
Jaret Burkett
2023-07-29 17:14:14 -06:00
parent 8ba1b11557
commit c35b78f0d4
2 changed files with 14 additions and 7 deletions

View File

@@ -16,8 +16,8 @@ sys.path.append(REPOS_ROOT)
# Lookup table from process type key to the name of the process class.
# Every key must appear exactly once: a dict literal silently keeps only the
# last value for a repeated key, so a stale duplicate is dead code that
# misleads readers about which class is actually selected.
process_dict = {
    'vae': 'TrainVAEProcess',
    'slider_dev': 'TrainSliderProcess',
    'slider': 'TrainSliderProcess',
    # previous slider implementation, kept reachable under its own key
    'slider_old': 'TrainSliderProcessOld',
    'lora_hack': 'TrainLoRAHack',
    'rescale_sd': 'TrainSDRescaleProcess',
}

View File

@@ -224,6 +224,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
neutral=cache[neutral],
action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE,
multiplier=target.multiplier,
both_targets=cache[f"{target.positive} {target.negative}"],
empty_prompt=cache[""],
weight=target.weight
),
@@ -240,6 +241,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
neutral=cache[neutral],
action=ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE,
multiplier=target.multiplier,
both_targets=cache[f"{target.positive} {target.negative}"],
empty_prompt=cache[""],
weight=target.weight
),
@@ -255,6 +257,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
negative_target_with_neutral=cache[f"{target.positive} {neutral}"],
neutral=cache[neutral],
action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE,
both_targets=cache[f"{target.positive} {target.negative}"],
empty_prompt=cache[""],
multiplier=target.multiplier * -1.0,
weight=target.weight
@@ -269,6 +272,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
positive_target_with_neutral=cache[f"{target.positive} {neutral}"],
negative_target=cache[f"{target.negative}"],
negative_target_with_neutral=cache[f"{target.negative} {neutral}"],
both_targets=cache[f"{target.positive} {target.negative}"],
neutral=cache[neutral],
action=ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE,
empty_prompt=cache[""],
@@ -335,6 +339,9 @@ class TrainSliderProcess(BaseSDTrainProcess):
def hook_train_loop(self):
dtype = get_torch_dtype(self.train_config.dtype)
# draw a random weight in [1, 3) (torch.rand samples [0, 1)); it scales the network multiplier below
rand_weight = torch.rand((1,)).item() * 2 + 1
# get a random pair
prompt_pair: EncodedPromptPair = self.prompt_pairs[
torch.randint(0, len(self.prompt_pairs), (1,)).item()
@@ -373,7 +380,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
)
# set network multiplier
self.network.multiplier = multiplier
self.network.multiplier = multiplier * rand_weight
with torch.no_grad():
self.sd.noise_scheduler.set_timesteps(
@@ -399,7 +406,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
with self.network:
assert self.network.is_active
self.network.multiplier = multiplier
self.network.multiplier = multiplier * rand_weight
denoised_latents = self.diffuse_some_steps(
latents, # pass simple noise latents
train_tools.concat_prompt_embeddings(
@@ -443,16 +450,16 @@ class TrainSliderProcess(BaseSDTrainProcess):
with self.network:
# anchor whatever weight prompt pair is using
pos_nem_mult = 1.0 if prompt_pair.multiplier > 0 else -1.0
self.network.multiplier = anchor.multiplier * pos_nem_mult
self.network.multiplier = anchor.multiplier * pos_nem_mult * rand_weight
anchor_pred_noise = get_noise_pred(
anchor.prompt, anchor.neg_prompt, 1, current_timestep, denoised_latents
).to("cpu", dtype=torch.float32)
self.network.multiplier = prompt_pair.multiplier
self.network.multiplier = prompt_pair.multiplier * rand_weight
with self.network:
self.network.multiplier = prompt_pair.multiplier
self.network.multiplier = prompt_pair.multiplier * rand_weight
target_latents = get_noise_pred(
positive, target_class, 1, current_timestep, denoised_latents
).to("cpu", dtype=torch.float32)