Fixed issue with shuffeling permutations

This commit is contained in:
Jaret Burkett
2023-08-23 22:02:00 -06:00
parent b408f9f3eb
commit aeaca13d69
2 changed files with 16 additions and 2 deletions

View File

@@ -36,6 +36,12 @@ class TrainSliderProcess(BaseSDTrainProcess):
# keep track of prompt chunk size
self.prompt_chunk_size = 1
# check if we have more targets than steps
# this can happen because of permutation son shuffling
if len(self.slider_config.targets) > self.train_config.steps:
# trim targets
self.slider_config.targets = self.slider_config.targets[:self.train_config.steps]
def before_model_load(self):
pass
@@ -87,7 +93,8 @@ class TrainSliderProcess(BaseSDTrainProcess):
prompts_to_cache = list(dict.fromkeys(prompts_to_cache))
# trim to max steps if max steps is lower than prompt count
prompts_to_cache = prompts_to_cache[:self.train_config.steps]
# todo, this can break if we have more targets than steps, should be fixed, by reducing permuations, but could stil happen with low steps
# prompts_to_cache = prompts_to_cache[:self.train_config.steps]
# encode them
cache = encode_prompts_to_cache(

View File

@@ -4,6 +4,7 @@ from typing import Optional, TYPE_CHECKING, List
import torch
from safetensors.torch import load_file, save_file
from tqdm import tqdm
import random
from toolkit.stable_diffusion_model import PromptEmbeds
from toolkit.train_tools import get_torch_dtype
@@ -248,7 +249,7 @@ def get_permutations(s):
return [', '.join(permutation) for permutation in permutations]
def get_slider_target_permutations(target: 'SliderTargetConfig') -> List['SliderTargetConfig']:
def get_slider_target_permutations(target: 'SliderTargetConfig', max_permutations=8) -> List['SliderTargetConfig']:
from toolkit.config_modules import SliderTargetConfig
pos_permutations = get_permutations(target.positive)
neg_permutations = get_permutations(target.negative)
@@ -265,6 +266,12 @@ def get_slider_target_permutations(target: 'SliderTargetConfig') -> List['Slider
)
)
# shuffle the list
random.shuffle(permutations)
if len(permutations) > max_permutations:
permutations = permutations[:max_permutations]
return permutations