diff --git a/jobs/process/TrainSliderProcess.py b/jobs/process/TrainSliderProcess.py index 97a072ae..73d7daf4 100644 --- a/jobs/process/TrainSliderProcess.py +++ b/jobs/process/TrainSliderProcess.py @@ -36,6 +36,12 @@ class TrainSliderProcess(BaseSDTrainProcess): # keep track of prompt chunk size self.prompt_chunk_size = 1 + # check if we have more targets than steps + # this can happen because of permutation son shuffling + if len(self.slider_config.targets) > self.train_config.steps: + # trim targets + self.slider_config.targets = self.slider_config.targets[:self.train_config.steps] + def before_model_load(self): pass @@ -87,7 +93,8 @@ class TrainSliderProcess(BaseSDTrainProcess): prompts_to_cache = list(dict.fromkeys(prompts_to_cache)) # trim to max steps if max steps is lower than prompt count - prompts_to_cache = prompts_to_cache[:self.train_config.steps] + # todo, this can break if we have more targets than steps, should be fixed, by reducing permuations, but could stil happen with low steps + # prompts_to_cache = prompts_to_cache[:self.train_config.steps] # encode them cache = encode_prompts_to_cache( diff --git a/toolkit/prompt_utils.py b/toolkit/prompt_utils.py index d7e9da1f..813c3bad 100644 --- a/toolkit/prompt_utils.py +++ b/toolkit/prompt_utils.py @@ -4,6 +4,7 @@ from typing import Optional, TYPE_CHECKING, List import torch from safetensors.torch import load_file, save_file from tqdm import tqdm +import random from toolkit.stable_diffusion_model import PromptEmbeds from toolkit.train_tools import get_torch_dtype @@ -248,7 +249,7 @@ def get_permutations(s): return [', '.join(permutation) for permutation in permutations] -def get_slider_target_permutations(target: 'SliderTargetConfig') -> List['SliderTargetConfig']: +def get_slider_target_permutations(target: 'SliderTargetConfig', max_permutations=8) -> List['SliderTargetConfig']: from toolkit.config_modules import SliderTargetConfig pos_permutations = get_permutations(target.positive) neg_permutations = get_permutations(target.negative) @@ -265,6 +266,12 @@ def get_slider_target_permutations(target: 'SliderTargetConfig') -> List['Slider ) ) + # shuffle the list + random.shuffle(permutations) + + if len(permutations) > max_permutations: + permutations = permutations[:max_permutations] + return permutations