mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-03 01:29:50 +00:00
Added random weight adjuster to prevent overfitting
This commit is contained in:
@@ -16,8 +16,8 @@ sys.path.append(REPOS_ROOT)
|
||||
|
||||
process_dict = {
|
||||
'vae': 'TrainVAEProcess',
|
||||
'slider_dev': 'TrainSliderProcess',
|
||||
'slider': 'TrainSliderProcessOld',
|
||||
'slider': 'TrainSliderProcess',
|
||||
'slider_old': 'TrainSliderProcessOld',
|
||||
'lora_hack': 'TrainLoRAHack',
|
||||
'rescale_sd': 'TrainSDRescaleProcess',
|
||||
}
|
||||
|
||||
@@ -224,6 +224,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
neutral=cache[neutral],
|
||||
action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE,
|
||||
multiplier=target.multiplier,
|
||||
both_targets=cache[f"{target.positive} {target.negative}"],
|
||||
empty_prompt=cache[""],
|
||||
weight=target.weight
|
||||
),
|
||||
@@ -240,6 +241,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
neutral=cache[neutral],
|
||||
action=ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE,
|
||||
multiplier=target.multiplier,
|
||||
both_targets=cache[f"{target.positive} {target.negative}"],
|
||||
empty_prompt=cache[""],
|
||||
weight=target.weight
|
||||
),
|
||||
@@ -255,6 +257,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
negative_target_with_neutral=cache[f"{target.positive} {neutral}"],
|
||||
neutral=cache[neutral],
|
||||
action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE,
|
||||
both_targets=cache[f"{target.positive} {target.negative}"],
|
||||
empty_prompt=cache[""],
|
||||
multiplier=target.multiplier * -1.0,
|
||||
weight=target.weight
|
||||
@@ -269,6 +272,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
positive_target_with_neutral=cache[f"{target.positive} {neutral}"],
|
||||
negative_target=cache[f"{target.negative}"],
|
||||
negative_target_with_neutral=cache[f"{target.negative} {neutral}"],
|
||||
both_targets=cache[f"{target.positive} {target.negative}"],
|
||||
neutral=cache[neutral],
|
||||
action=ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE,
|
||||
empty_prompt=cache[""],
|
||||
@@ -335,6 +339,9 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
def hook_train_loop(self):
|
||||
dtype = get_torch_dtype(self.train_config.dtype)
|
||||
|
||||
# get a random multiplier in the range [1, 3)
|
||||
rand_weight = torch.rand((1,)).item() * 2 + 1
|
||||
|
||||
# get a random pair
|
||||
prompt_pair: EncodedPromptPair = self.prompt_pairs[
|
||||
torch.randint(0, len(self.prompt_pairs), (1,)).item()
|
||||
@@ -373,7 +380,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
)
|
||||
|
||||
# set network multiplier
|
||||
self.network.multiplier = multiplier
|
||||
self.network.multiplier = multiplier * rand_weight
|
||||
|
||||
with torch.no_grad():
|
||||
self.sd.noise_scheduler.set_timesteps(
|
||||
@@ -399,7 +406,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
|
||||
with self.network:
|
||||
assert self.network.is_active
|
||||
self.network.multiplier = multiplier
|
||||
self.network.multiplier = multiplier * rand_weight
|
||||
denoised_latents = self.diffuse_some_steps(
|
||||
latents, # pass simple noise latents
|
||||
train_tools.concat_prompt_embeddings(
|
||||
@@ -443,16 +450,16 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
with self.network:
|
||||
# anchor using the same sign as the prompt pair's multiplier
|
||||
pos_nem_mult = 1.0 if prompt_pair.multiplier > 0 else -1.0
|
||||
self.network.multiplier = anchor.multiplier * pos_nem_mult
|
||||
self.network.multiplier = anchor.multiplier * pos_nem_mult * rand_weight
|
||||
|
||||
anchor_pred_noise = get_noise_pred(
|
||||
anchor.prompt, anchor.neg_prompt, 1, current_timestep, denoised_latents
|
||||
).to("cpu", dtype=torch.float32)
|
||||
|
||||
self.network.multiplier = prompt_pair.multiplier
|
||||
self.network.multiplier = prompt_pair.multiplier * rand_weight
|
||||
|
||||
with self.network:
|
||||
self.network.multiplier = prompt_pair.multiplier
|
||||
self.network.multiplier = prompt_pair.multiplier * rand_weight
|
||||
target_latents = get_noise_pred(
|
||||
positive, target_class, 1, current_timestep, denoised_latents
|
||||
).to("cpu", dtype=torch.float32)
|
||||
|
||||
Reference in New Issue
Block a user