mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-28 18:21:16 +00:00
Fixed issue with shuffeling permutations
This commit is contained in:
@@ -4,6 +4,7 @@ from typing import Optional, TYPE_CHECKING, List
|
||||
import torch
|
||||
from safetensors.torch import load_file, save_file
|
||||
from tqdm import tqdm
|
||||
import random
|
||||
|
||||
from toolkit.stable_diffusion_model import PromptEmbeds
|
||||
from toolkit.train_tools import get_torch_dtype
|
||||
@@ -248,7 +249,7 @@ def get_permutations(s):
|
||||
return [', '.join(permutation) for permutation in permutations]
|
||||
|
||||
|
||||
def get_slider_target_permutations(target: 'SliderTargetConfig') -> List['SliderTargetConfig']:
|
||||
def get_slider_target_permutations(target: 'SliderTargetConfig', max_permutations=8) -> List['SliderTargetConfig']:
|
||||
from toolkit.config_modules import SliderTargetConfig
|
||||
pos_permutations = get_permutations(target.positive)
|
||||
neg_permutations = get_permutations(target.negative)
|
||||
@@ -265,6 +266,12 @@ def get_slider_target_permutations(target: 'SliderTargetConfig') -> List['Slider
|
||||
)
|
||||
)
|
||||
|
||||
# shuffle the list
|
||||
random.shuffle(permutations)
|
||||
|
||||
if len(permutations) > max_permutations:
|
||||
permutations = permutations[:max_permutations]
|
||||
|
||||
return permutations
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user