Fixed issue with shuffeling permutations

This commit is contained in:
Jaret Burkett
2023-08-23 22:02:00 -06:00
parent b408f9f3eb
commit aeaca13d69
2 changed files with 16 additions and 2 deletions

View File

@@ -4,6 +4,7 @@ from typing import Optional, TYPE_CHECKING, List
import torch
from safetensors.torch import load_file, save_file
from tqdm import tqdm
import random
from toolkit.stable_diffusion_model import PromptEmbeds
from toolkit.train_tools import get_torch_dtype
@@ -248,7 +249,7 @@ def get_permutations(s):
return [', '.join(permutation) for permutation in permutations]
def get_slider_target_permutations(target: 'SliderTargetConfig') -> List['SliderTargetConfig']:
def get_slider_target_permutations(target: 'SliderTargetConfig', max_permutations=8) -> List['SliderTargetConfig']:
from toolkit.config_modules import SliderTargetConfig
pos_permutations = get_permutations(target.positive)
neg_permutations = get_permutations(target.negative)
@@ -265,6 +266,12 @@ def get_slider_target_permutations(target: 'SliderTargetConfig') -> List['Slider
)
)
# shuffle the list
random.shuffle(permutations)
if len(permutations) > max_permutations:
permutations = permutations[:max_permutations]
return permutations