mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-02-02 20:07:53 +00:00
Fixed issue with shuffeling permutations
This commit is contained in:
@@ -36,6 +36,12 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
# keep track of prompt chunk size
|
||||
self.prompt_chunk_size = 1
|
||||
|
||||
# check if we have more targets than steps
|
||||
# this can happen because of permutation son shuffling
|
||||
if len(self.slider_config.targets) > self.train_config.steps:
|
||||
# trim targets
|
||||
self.slider_config.targets = self.slider_config.targets[:self.train_config.steps]
|
||||
|
||||
def before_model_load(self):
|
||||
pass
|
||||
|
||||
@@ -87,7 +93,8 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
prompts_to_cache = list(dict.fromkeys(prompts_to_cache))
|
||||
|
||||
# trim to max steps if max steps is lower than prompt count
|
||||
prompts_to_cache = prompts_to_cache[:self.train_config.steps]
|
||||
# todo, this can break if we have more targets than steps, should be fixed, by reducing permuations, but could stil happen with low steps
|
||||
# prompts_to_cache = prompts_to_cache[:self.train_config.steps]
|
||||
|
||||
# encode them
|
||||
cache = encode_prompts_to_cache(
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Optional, TYPE_CHECKING, List
|
||||
import torch
|
||||
from safetensors.torch import load_file, save_file
|
||||
from tqdm import tqdm
|
||||
import random
|
||||
|
||||
from toolkit.stable_diffusion_model import PromptEmbeds
|
||||
from toolkit.train_tools import get_torch_dtype
|
||||
@@ -248,7 +249,7 @@ def get_permutations(s):
|
||||
return [', '.join(permutation) for permutation in permutations]
|
||||
|
||||
|
||||
def get_slider_target_permutations(target: 'SliderTargetConfig') -> List['SliderTargetConfig']:
|
||||
def get_slider_target_permutations(target: 'SliderTargetConfig', max_permutations=8) -> List['SliderTargetConfig']:
|
||||
from toolkit.config_modules import SliderTargetConfig
|
||||
pos_permutations = get_permutations(target.positive)
|
||||
neg_permutations = get_permutations(target.negative)
|
||||
@@ -265,6 +266,12 @@ def get_slider_target_permutations(target: 'SliderTargetConfig') -> List['Slider
|
||||
)
|
||||
)
|
||||
|
||||
# shuffle the list
|
||||
random.shuffle(permutations)
|
||||
|
||||
if len(permutations) > max_permutations:
|
||||
permutations = permutations[:max_permutations]
|
||||
|
||||
return permutations
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user