mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-28 18:21:16 +00:00
Added shuffeling to prompts
This commit is contained in:
@@ -4,6 +4,7 @@ from typing import List, Optional
|
||||
import random
|
||||
|
||||
|
||||
|
||||
class SaveConfig:
|
||||
def __init__(self, **kwargs):
|
||||
self.save_every: int = kwargs.get('save_every', 1000)
|
||||
@@ -93,6 +94,7 @@ class SliderTargetConfig:
|
||||
self.negative: str = kwargs.get('negative', '')
|
||||
self.multiplier: float = kwargs.get('multiplier', 1.0)
|
||||
self.weight: float = kwargs.get('weight', 1.0)
|
||||
self.shuffle: bool = kwargs.get('shuffle', False)
|
||||
|
||||
|
||||
class SliderConfigAnchors:
|
||||
@@ -105,8 +107,6 @@ class SliderConfigAnchors:
|
||||
class SliderConfig:
|
||||
def __init__(self, **kwargs):
|
||||
targets = kwargs.get('targets', [])
|
||||
targets = [SliderTargetConfig(**target) for target in targets]
|
||||
self.targets: List[SliderTargetConfig] = targets
|
||||
anchors = kwargs.get('anchors', [])
|
||||
anchors = [SliderConfigAnchors(**anchor) for anchor in anchors]
|
||||
self.anchors: List[SliderConfigAnchors] = anchors
|
||||
@@ -115,6 +115,18 @@ class SliderConfig:
|
||||
self.prompt_tensors: str = kwargs.get('prompt_tensors', None)
|
||||
self.batch_full_slide: bool = kwargs.get('batch_full_slide', True)
|
||||
|
||||
# expand targets if shuffling
|
||||
from toolkit.prompt_utils import get_slider_target_permutations
|
||||
self.targets: List[SliderTargetConfig] = []
|
||||
targets = [SliderTargetConfig(**target) for target in targets]
|
||||
# do permutations if shuffle is true
|
||||
for target in targets:
|
||||
if target.shuffle:
|
||||
target_permutations = get_slider_target_permutations(target)
|
||||
self.targets = self.targets + target_permutations
|
||||
else:
|
||||
self.targets.append(target)
|
||||
|
||||
|
||||
class GenerateImageConfig:
|
||||
def __init__(
|
||||
|
||||
Reference in New Issue
Block a user