mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-26 17:29:27 +00:00
Hude rework to move the batch to a DTO to make it far more modular to the future ui
This commit is contained in:
@@ -5,6 +5,7 @@ import random
|
||||
|
||||
ImgExt = Literal['jpg', 'png', 'webp']
|
||||
|
||||
|
||||
class SaveConfig:
|
||||
def __init__(self, **kwargs):
|
||||
self.save_every: int = kwargs.get('save_every', 1000)
|
||||
@@ -167,9 +168,13 @@ class DatasetConfig:
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.type = kwargs.get('type', 'image') # sd, slider, reference
|
||||
# will be legacy
|
||||
self.folder_path: str = kwargs.get('folder_path', None)
|
||||
# can be json or folder path
|
||||
self.dataset_path: str = kwargs.get('dataset_path', None)
|
||||
|
||||
self.default_caption: str = kwargs.get('default_caption', None)
|
||||
self.caption_type: str = kwargs.get('caption_type', None)
|
||||
self.caption_ext: str = kwargs.get('caption_ext', None)
|
||||
self.random_scale: bool = kwargs.get('random_scale', False)
|
||||
self.random_crop: bool = kwargs.get('random_crop', False)
|
||||
self.resolution: int = kwargs.get('resolution', 512)
|
||||
@@ -182,6 +187,33 @@ class DatasetConfig:
|
||||
self.shuffle_tokens: bool = kwargs.get('shuffle_tokens', False)
|
||||
self.caption_dropout_rate: float = float(kwargs.get('caption_dropout_rate', 0.0))
|
||||
|
||||
# legacy compatability
|
||||
legacy_caption_type = kwargs.get('caption_type', None)
|
||||
if legacy_caption_type:
|
||||
self.caption_ext = legacy_caption_type
|
||||
self.caption_type = self.caption_ext
|
||||
|
||||
|
||||
def preprocess_dataset_raw_config(raw_config: List[dict]) -> List[dict]:
|
||||
"""
|
||||
This just splits up the datasets by resolutions so you dont have to do it manually
|
||||
:param raw_config:
|
||||
:return:
|
||||
"""
|
||||
# split up datasets by resolutions
|
||||
new_config = []
|
||||
for dataset in raw_config:
|
||||
resolution = dataset.get('resolution', 512)
|
||||
if isinstance(resolution, list):
|
||||
resolution_list = resolution
|
||||
else:
|
||||
resolution_list = [resolution]
|
||||
for res in resolution_list:
|
||||
dataset_copy = dataset.copy()
|
||||
dataset_copy['resolution'] = res
|
||||
new_config.append(dataset_copy)
|
||||
return new_config
|
||||
|
||||
|
||||
class GenerateImageConfig:
|
||||
def __init__(
|
||||
|
||||
Reference in New Issue
Block a user