Mirror of https://github.com/ostris/ai-toolkit.git
synced 2026-02-05 13:09:57 +00:00
Huge rework to move the batch to a DTO to make it far more modular for the future UI
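As a rough orientation for the hunks below, a minimal sketch (not part of the commit) of the batch access pattern the new DTO enables; the class, attribute, and method names are taken from the diff itself.

# Sketch only: the dataloader used to yield a raw (imgs, prompts, dataset_config) tuple;
# after this rework it yields a DataLoaderBatchDTO that owns the tensors and metadata.
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO

def consume_batch(batch: DataLoaderBatchDTO):
    imgs = batch.tensor                        # stacked image tensor for the whole batch
    prompts = batch.get_caption_list()         # per-item captions
    is_reg_list = batch.get_is_reg_list()      # per-item regularization flags
    weights = batch.get_network_weight_list()  # per-item network weights
    # ... run the training step on imgs / prompts ...
    batch.cleanup()                            # drop tensor references once the step is done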
@@ -35,6 +35,7 @@ class SDTrainer(BaseSDTrainProcess):
    def hook_train_loop(self, batch):
        dtype = get_torch_dtype(self.train_config.dtype)
        noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)
        network_weight_list = batch.get_network_weight_list()

        self.optimizer.zero_grad()
        flush()
@@ -53,6 +54,9 @@ class SDTrainer(BaseSDTrainProcess):
        else:
            network = BlankNetwork()

        # set the weights
        network.multiplier = network_weight_list

        # activate network if it exits
        with network:
            with torch.set_grad_enabled(grad_on_text_encoder):
@@ -114,5 +118,7 @@ class SDTrainer(BaseSDTrainProcess):
        loss_dict = OrderedDict(
            {'loss': loss.item()}
        )
        # reset network multiplier
        network.multiplier = 1.0

        return loss_dict
@@ -19,7 +19,7 @@ config:
      max_step_saves_to_keep: 5 # only affects step counts
      datasets:
        - folder_path: "/path/to/dataset"
          caption_type: "txt"
          caption_ext: "txt"
          default_caption: "[trigger]"
          buckets: true
          resolution: 512
@@ -6,6 +6,7 @@ from typing import Union
from torch.utils.data import DataLoader

from toolkit.data_loader import get_dataloader_from_datasets
from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
from toolkit.embedding import Embedding
from toolkit.lora_special import LoRASpecialNetwork
from toolkit.optimizer import get_optimizer
@@ -23,7 +24,7 @@ import torch
from tqdm import tqdm

from toolkit.config_modules import SaveConfig, LogingConfig, SampleConfig, NetworkConfig, TrainConfig, ModelConfig, \
    GenerateImageConfig, EmbeddingConfig, DatasetConfig
    GenerateImageConfig, EmbeddingConfig, DatasetConfig, preprocess_dataset_raw_config


def flush():
@@ -67,6 +68,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
        self.trigger_word = self.get_conf('trigger_word', None)

        raw_datasets = self.get_conf('datasets', None)
        if raw_datasets is not None and len(raw_datasets) > 0:
            raw_datasets = preprocess_dataset_raw_config(raw_datasets)
        self.datasets = None
        self.datasets_reg = None
        if raw_datasets is not None and len(raw_datasets) > 0:
@@ -94,6 +97,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
        if latest_save_path is not None:
            print(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####")
            self.model_config.name_or_path = latest_save_path
            meta = load_metadata_from_safetensors(latest_save_path)
            # if 'training_info' in Orderdict keys
            if 'training_info' in meta and 'step' in meta['training_info']:
                self.step_num = meta['training_info']['step']
                self.start_step = self.step_num
                print(f"Found step {self.step_num} in metadata, starting from there")

        self.sd = StableDiffusion(
            device=self.device,
@@ -307,16 +316,9 @@ class BaseSDTrainProcess(BaseTrainProcess):

    def process_general_training_batch(self, batch):
        with torch.no_grad():
            imgs, prompts, dataset_config = batch

            # convert the 0 or 1 for is reg to a bool list
            if isinstance(dataset_config, list):
                is_reg_list = [x.get('is_reg', 0) for x in dataset_config]
            else:
                is_reg_list = dataset_config.get('is_reg', [0 for _ in range(imgs.shape[0])])
            if isinstance(is_reg_list, torch.Tensor):
                is_reg_list = is_reg_list.numpy().tolist()
            is_reg_list = [bool(x) for x in is_reg_list]
            imgs = batch.tensor
            prompts = batch.get_caption_list()
            is_reg_list = batch.get_is_reg_list()

            conditioned_prompts = []
@@ -473,6 +475,7 @@ class BaseSDTrainProcess(BaseTrainProcess):

            # resume state from embedding
            self.step_num = self.embedding.step
            self.start_step = self.step_num

            # set trainable params
            params = self.embedding.get_trainable_params()
@@ -556,13 +559,18 @@ class BaseSDTrainProcess(BaseTrainProcess):
            with torch.no_grad():
                # if is even step and we have a reg dataset, use that
                # todo improve this logic to send one of each through if we can buckets and batch size might be an issue
                if step % 2 == 0 and dataloader_reg is not None:
                is_reg_step = False
                is_save_step = self.save_config.save_every and self.step_num % self.save_config.save_every == 0
                is_sample_step = self.sample_config.sample_every and self.step_num % self.sample_config.sample_every == 0
                # don't do a reg step on sample or save steps as we dont want to normalize on those
                if step % 2 == 0 and dataloader_reg is not None and not is_save_step and not is_sample_step:
                    try:
                        batch = next(dataloader_iterator_reg)
                    except StopIteration:
                        # hit the end of an epoch, reset
                        dataloader_iterator_reg = iter(dataloader_reg)
                        batch = next(dataloader_iterator_reg)
                    is_reg_step = True
                elif dataloader is not None:
                    try:
                        batch = next(dataloader_iterator)
@@ -601,11 +609,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
                if self.step_num != self.start_step:
                    # pause progress bar
                    self.progress_bar.unpause()  # makes it so doesn't track time
                    if self.sample_config.sample_every and self.step_num % self.sample_config.sample_every == 0:
                    if is_sample_step:
                        # print above the progress bar
                        self.sample(self.step_num)

                    if self.save_config.save_every and self.step_num % self.save_config.save_every == 0:
                    if is_save_step:
                        # print above the progress bar
                        self.print(f"Saving at step {self.step_num}")
                        self.save(self.step_num)
@@ -623,10 +631,14 @@ class BaseSDTrainProcess(BaseTrainProcess):
                # end of step
                self.step_num = step

                # apply network normalizer if we are using it
                if self.network is not None and self.network.is_normalizing:
                # apply network normalizer if we are using it, not on regularization steps
                if self.network is not None and self.network.is_normalizing and not is_reg_step:
                    self.network.apply_stored_normalizer()

                # if the batch is a DataLoaderBatchDTO, then we need to clean it up
                if isinstance(batch, DataLoaderBatchDTO):
                    batch.cleanup()

        self.sample(self.step_num + 1)
        print("")
        self.save()
@@ -1,76 +0,0 @@
# ref:
# - https://github.com/p1atdev/LECO/blob/main/train_lora.py
import time
from collections import OrderedDict
import os

from toolkit.config_modules import SliderConfig
from toolkit.paths import REPOS_ROOT
import sys

sys.path.append(REPOS_ROOT)
sys.path.append(os.path.join(REPOS_ROOT, 'leco'))
from toolkit.train_tools import get_torch_dtype, apply_noise_offset
import gc

import torch
from leco import train_util, model_util
from leco.prompt_util import PromptEmbedsCache
from .BaseSDTrainProcess import BaseSDTrainProcess, StableDiffusion


def flush():
    torch.cuda.empty_cache()
    gc.collect()


class LoRAHack:
    def __init__(self, **kwargs):
        self.type = kwargs.get('type', 'suppression')


class TrainLoRAHack(BaseSDTrainProcess):
    def __init__(self, process_id: int, job, config: OrderedDict):
        super().__init__(process_id, job, config)
        self.hack_config = LoRAHack(**self.get_conf('hack', {}))

    def hook_before_train_loop(self):
        # we don't need text encoder so move it to cpu
        self.sd.text_encoder.to("cpu")
        flush()
        # end hook_before_train_loop

        if self.hack_config.type == 'suppression':
            # set all params to self.current_suppression
            params = self.network.parameters()
            for param in params:
                # get random noise for each param
                noise = torch.randn_like(param) - 0.5
                # apply noise to param
                param.data = noise * 0.001


    def supress_loop(self):
        dtype = get_torch_dtype(self.train_config.dtype)


        loss_dict = OrderedDict(
            {'sup': 0.0}
        )
        # increase noise
        for param in self.network.parameters():
            # get random noise for each param
            noise = torch.randn_like(param) - 0.5
            # apply noise to param
            param.data = param.data + noise * 0.001



        return loss_dict

    def hook_train_loop(self, batch):
        if self.hack_config.type == 'suppression':
            return self.supress_loop()
        else:
            raise NotImplementedError(f'unknown hack type: {self.hack_config.type}')
        # end hook_train_loop
@@ -7,7 +7,6 @@ from .TrainVAEProcess import TrainVAEProcess
from .BaseMergeProcess import BaseMergeProcess
from .TrainSliderProcess import TrainSliderProcess
from .TrainSliderProcessOld import TrainSliderProcessOld
from .TrainLoRAHack import TrainLoRAHack
from .TrainSDRescaleProcess import TrainSDRescaleProcess
from .ModRescaleLoraProcess import ModRescaleLoraProcess
from .GenerateProcess import GenerateProcess
@@ -22,7 +22,7 @@ batch_size = 4
dataset_config = DatasetConfig(
    folder_path=dataset_folder,
    resolution=resolution,
    caption_type='txt',
    caption_ext='txt',
    default_caption='default',
    buckets=True,
    bucket_tolerance=bucket_tolerance,
@@ -5,6 +5,7 @@ import random

ImgExt = Literal['jpg', 'png', 'webp']


class SaveConfig:
    def __init__(self, **kwargs):
        self.save_every: int = kwargs.get('save_every', 1000)
@@ -167,9 +168,13 @@ class DatasetConfig:

    def __init__(self, **kwargs):
        self.type = kwargs.get('type', 'image')  # sd, slider, reference
        # will be legacy
        self.folder_path: str = kwargs.get('folder_path', None)
        # can be json or folder path
        self.dataset_path: str = kwargs.get('dataset_path', None)

        self.default_caption: str = kwargs.get('default_caption', None)
        self.caption_type: str = kwargs.get('caption_type', None)
        self.caption_ext: str = kwargs.get('caption_ext', None)
        self.random_scale: bool = kwargs.get('random_scale', False)
        self.random_crop: bool = kwargs.get('random_crop', False)
        self.resolution: int = kwargs.get('resolution', 512)
@@ -182,6 +187,33 @@ class DatasetConfig:
        self.shuffle_tokens: bool = kwargs.get('shuffle_tokens', False)
        self.caption_dropout_rate: float = float(kwargs.get('caption_dropout_rate', 0.0))

        # legacy compatability
        legacy_caption_type = kwargs.get('caption_type', None)
        if legacy_caption_type:
            self.caption_ext = legacy_caption_type
        self.caption_type = self.caption_ext


def preprocess_dataset_raw_config(raw_config: List[dict]) -> List[dict]:
    """
    This just splits up the datasets by resolutions so you dont have to do it manually
    :param raw_config:
    :return:
    """
    # split up datasets by resolutions
    new_config = []
    for dataset in raw_config:
        resolution = dataset.get('resolution', 512)
        if isinstance(resolution, list):
            resolution_list = resolution
        else:
            resolution_list = [resolution]
        for res in resolution_list:
            dataset_copy = dataset.copy()
            dataset_copy['resolution'] = res
            new_config.append(dataset_copy)
    return new_config


class GenerateImageConfig:
    def __init__(
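To illustrate the preprocess_dataset_raw_config helper shown above, a small worked example with made-up values: an entry whose resolution is a list is split into one copy per resolution.

from toolkit.config_modules import preprocess_dataset_raw_config

raw = [{'folder_path': '/path/to/dataset', 'resolution': [512, 768]}]
print(preprocess_dataset_raw_config(raw))
# [{'folder_path': '/path/to/dataset', 'resolution': 512},
#  {'folder_path': '/path/to/dataset', 'resolution': 768}]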
@@ -1,3 +1,4 @@
import json
import os
import random
from typing import List
@@ -13,10 +14,9 @@ from tqdm import tqdm
import albumentations as A

from toolkit import image_utils
from toolkit.config_modules import DatasetConfig
from toolkit.config_modules import DatasetConfig, preprocess_dataset_raw_config
from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin
from toolkit.data_transfer_object.data_loader import FileItemDTO

from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO


class ImageDataset(Dataset, CaptionMixin):
@@ -29,7 +29,7 @@ class ImageDataset(Dataset, CaptionMixin):
        self.include_prompt = self.get_config('include_prompt', False)
        self.default_prompt = self.get_config('default_prompt', '')
        if self.include_prompt:
            self.caption_type = self.get_config('caption_type', 'txt')
            self.caption_type = self.get_config('caption_ext', 'txt')
        else:
            self.caption_type = None
        # we always random crop if random scale is enabled
@@ -288,24 +288,17 @@ class PairedImageDataset(Dataset):
        return img, prompt, (self.neg_weight, self.pos_weight)


printed_messages = []


def print_once(msg):
    global printed_messages
    if msg not in printed_messages:
        print(msg)
        printed_messages.append(msg)


class AiToolkitDataset(Dataset, CaptionMixin, BucketsMixin):

    def __init__(self, dataset_config: 'DatasetConfig', batch_size=1):
        super().__init__()
        self.dataset_config = dataset_config
        self.folder_path = dataset_config.folder_path
        self.caption_type = dataset_config.caption_type
        folder_path = dataset_config.folder_path
        self.dataset_path = dataset_config.dataset_path
        if self.dataset_path is None:
            self.dataset_path = folder_path

        self.caption_type = dataset_config.caption_ext
        self.default_caption = dataset_config.default_caption
        self.random_scale = dataset_config.random_scale
        self.scale = dataset_config.scale
@@ -313,147 +306,96 @@ class AiToolkitDataset(Dataset, CaptionMixin, BucketsMixin):
        # we always random crop if random scale is enabled
        self.random_crop = self.random_scale if self.random_scale else dataset_config.random_crop
        self.resolution = dataset_config.resolution
        self.caption_dict = None
        self.file_list: List['FileItemDTO'] = []

        # get the file list
        file_list = [
            os.path.join(self.folder_path, file) for file in os.listdir(self.folder_path) if
            file.lower().endswith(('.jpg', '.jpeg', '.png', '.webp'))
        ]
        # check if dataset_path is a folder or json
        if os.path.isdir(self.dataset_path):
            file_list = [
                os.path.join(self.dataset_path, file) for file in os.listdir(self.dataset_path) if
                file.lower().endswith(('.jpg', '.jpeg', '.png', '.webp'))
            ]
        else:
            # assume json
            with open(self.dataset_path, 'r') as f:
                self.caption_dict = json.load(f)
            # keys are file paths
            file_list = list(self.caption_dict.keys())

        # this might take a while
        print(f" - Preprocessing image dimensions")
        bad_count = 0
        for file in tqdm(file_list):
            try:
                w, h = image_utils.get_image_size(file)
            except image_utils.UnknownImageFormat:
                print_once(f'Warning: Some images in the dataset cannot be fast read. ' + \
                           f'This process is faster for png, jpeg')
                img = Image.open(file)
                h, w = img.size
            # TODO allow smaller images
            if int(min(h, w) * self.scale) >= self.resolution:
                self.file_list.append(
                    FileItemDTO(
                        path=file,
                        width=w,
                        height=h,
                        scale_to_width=int(w * self.scale),
                        scale_to_height=int(h * self.scale),
                        dataset_config=dataset_config
                    )
                )
            else:
            file_item = FileItemDTO(
                path=file,
                dataset_config=dataset_config
            )
            if file_item.scale_to_width < self.resolution or file_item.scale_to_height < self.resolution:
                bad_count += 1
            else:
                self.file_list.append(file_item)

        print(f" - Found {len(self.file_list)} images")
        print(f" - Found {bad_count} images that are too small")
        assert len(self.file_list) > 0, f"no images found in {self.folder_path}"
        assert len(self.file_list) > 0, f"no images found in {self.dataset_path}"

        if self.dataset_config.buckets:
            # setup buckets
            self.setup_buckets()
        self.setup_epoch()

        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),  # normalize to [-1, 1]
        ])
    def setup_epoch(self):
        # TODO: set this up to redo cropping and everything else
        # do not call for now
        if self.dataset_config.buckets:
            # setup buckets
            self.setup_buckets()

    def __len__(self):
        if self.dataset_config.buckets:
            return len(self.batch_indices)
        return len(self.file_list)

    def _get_single_item(self, index):
    def _get_single_item(self, index) -> 'FileItemDTO':
        file_item = self.file_list[index]
        # todo make sure this matches
        img = exif_transpose(Image.open(file_item.path)).convert('RGB')
        w, h = img.size
        if w > h and file_item.scale_to_width < file_item.scale_to_height:
            # throw error, they should match
            raise ValueError(
                f"unexpected values: w={w}, h={h}, file_item.scale_to_width={file_item.scale_to_width}, file_item.scale_to_height={file_item.scale_to_height}, file_item.path={file_item.path}")
        elif h > w and file_item.scale_to_height < file_item.scale_to_width:
            # throw error, they should match
            raise ValueError(
                f"unexpected values: w={w}, h={h}, file_item.scale_to_width={file_item.scale_to_width}, file_item.scale_to_height={file_item.scale_to_height}, file_item.path={file_item.path}")

        # Downscale the source image first
        img = img.resize((int(img.size[0] * self.scale), int(img.size[1] * self.scale)), Image.BICUBIC)
        min_img_size = min(img.size)

        if self.dataset_config.buckets:
            # todo allow scaling and cropping, will be hard to add
            # scale and crop based on file item
            img = img.resize((file_item.scale_to_width, file_item.scale_to_height), Image.BICUBIC)
            img = transforms.CenterCrop((file_item.crop_height, file_item.crop_width))(img)
        else:
            if self.random_crop:
                if self.random_scale and min_img_size > self.resolution:
                    if min_img_size < self.resolution:
                        print(
                            f"Unexpected values: min_img_size={min_img_size}, self.resolution={self.resolution}, image file={file_item.path}")
                        scale_size = self.resolution
                    else:
                        scale_size = random.randint(self.resolution, int(min_img_size))
                    img = img.resize((scale_size, scale_size), Image.BICUBIC)
                img = transforms.RandomCrop(self.resolution)(img)
            else:
                img = transforms.CenterCrop(min_img_size)(img)
                img = img.resize((self.resolution, self.resolution), Image.BICUBIC)

        img = self.transform(img)

        # todo convert it all
        dataset_config_dict = {
            "is_reg": 1 if self.dataset_config.is_reg else 0,
        }

        if self.caption_type is not None:
            prompt = self.get_caption_item(index)
            return img, prompt, dataset_config_dict
        else:
            return img, dataset_config_dict
        file_item.load_and_process_image(self.transform)
        file_item.load_caption(self.caption_dict)
        return file_item
    def __getitem__(self, item):
        if self.dataset_config.buckets:
            # for buckets we collate ourselves for now
            # todo allow a scheduler to dynamically make buckets
            # we collate ourselves
            idx_list = self.batch_indices[item]
            tensor_list = []
            prompt_list = []
            dataset_config_dict_list = []
            for idx in idx_list:
                if self.caption_type is not None:
                    img, prompt, dataset_config_dict = self._get_single_item(idx)
                    prompt_list.append(prompt)
                    dataset_config_dict_list.append(dataset_config_dict)
                else:
                    img, dataset_config_dict = self._get_single_item(idx)
                    dataset_config_dict_list.append(dataset_config_dict)
                tensor_list.append(img.unsqueeze(0))

            if self.caption_type is not None:
                return torch.cat(tensor_list, dim=0), prompt_list, dataset_config_dict_list
            else:
                return torch.cat(tensor_list, dim=0), dataset_config_dict_list
            return [self._get_single_item(idx) for idx in idx_list]
        else:
            # Dataloader is batching
            return self._get_single_item(item)


def get_dataloader_from_datasets(dataset_options, batch_size=1):
    # TODO do bucketing
    if dataset_options is None or len(dataset_options) == 0:
        return None

    datasets = []
    has_buckets = False

    dataset_config_list = []
    # preprocess them all
    for dataset_option in dataset_options:
        if isinstance(dataset_option, DatasetConfig):
            config = dataset_option
            dataset_config_list.append(dataset_option)
        else:
            config = DatasetConfig(**dataset_option)
            # preprocess raw data
            split_configs = preprocess_dataset_raw_config([dataset_option])
            for x in split_configs:
                dataset_config_list.append(DatasetConfig(**x))

    for config in dataset_config_list:

        if config.type == 'image':
            dataset = AiToolkitDataset(config, batch_size=batch_size)
            datasets.append(dataset)
@@ -463,21 +405,28 @@ def get_dataloader_from_datasets(dataset_options, batch_size=1):
            raise ValueError(f"invalid dataset type: {config.type}")

    concatenated_dataset = ConcatDataset(datasets)

    # todo build scheduler that can get buckets from all datasets that match
    # todo and evenly distribute reg images

    def dto_collation(batch: List['FileItemDTO']):
        # create DTO batch
        batch = DataLoaderBatchDTO(
            file_items=batch
        )
        return batch

    if has_buckets:
        # make sure they all have buckets
        for dataset in datasets:
            assert dataset.dataset_config.buckets, f"buckets not found on dataset {dataset.dataset_config.folder_path}, you either need all buckets or none"

        def custom_collate_fn(batch):
            # just return as is
            return batch

        data_loader = DataLoader(
            concatenated_dataset,
            batch_size=None,  # we batch in the dataloader
            batch_size=None,  # we batch in the datasets for now
            drop_last=False,
            shuffle=True,
            collate_fn=custom_collate_fn,  # Use the custom collate function
            collate_fn=dto_collation,  # Use the custom collate function
            num_workers=2
        )
    else:
@@ -485,6 +434,7 @@ def get_dataloader_from_datasets(dataset_options, batch_size=1):
            concatenated_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=2
            num_workers=2,
            collate_fn=dto_collation
        )
    return data_loader
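A usage sketch of the reworked loader (the folder path and settings are placeholders, and the dataset is assumed to contain images with .txt captions): with dto_collation wired in as the collate_fn, iterating the DataLoader yields DataLoaderBatchDTO objects in both the bucketed and non-bucketed branches.

from toolkit.data_loader import get_dataloader_from_datasets
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO

dataloader = get_dataloader_from_datasets(
    [{'folder_path': '/path/to/dataset', 'resolution': [512, 768], 'buckets': True}],
    batch_size=4,
)
for batch in dataloader:
    assert isinstance(batch, DataLoaderBatchDTO)
    print(batch.tensor.shape, batch.get_caption_list())
    batch.cleanup()  # free tensors between steps, mirroring the trainer's cleanup call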
@@ -1,36 +1,84 @@
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, List, Union
import torch
import random

from toolkit.dataloader_mixins import CaptionProcessingDTOMixin
from PIL import Image
from PIL.ImageOps import exif_transpose

from toolkit import image_utils
from toolkit.dataloader_mixins import CaptionProcessingDTOMixin, ImageProcessingDTOMixin

if TYPE_CHECKING:
    from toolkit.config_modules import DatasetConfig

printed_messages = []

class FileItemDTO(CaptionProcessingDTOMixin):

def print_once(msg):
    global printed_messages
    if msg not in printed_messages:
        print(msg)
        printed_messages.append(msg)
class FileItemDTO(CaptionProcessingDTOMixin, ImageProcessingDTOMixin):
    def __init__(self, **kwargs):
        self.path = kwargs.get('path', None)
        self.caption_path: str = kwargs.get('caption_path', None)
        self.dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
        # process width and height
        try:
            w, h = image_utils.get_image_size(self.path)
        except image_utils.UnknownImageFormat:
            print_once(f'Warning: Some images in the dataset cannot be fast read. ' + \
                       f'This process is faster for png, jpeg')
            img = exif_transpose(Image.open(self.path))
            h, w = img.size
        self.width: int = w
        self.height: int = h

        # self.caption_path: str = kwargs.get('caption_path', None)
        self.raw_caption: str = kwargs.get('raw_caption', None)
        self.width: int = kwargs.get('width', None)
        self.height: int = kwargs.get('height', None)
        # we scale first, then crop
        self.scale_to_width: int = kwargs.get('scale_to_width', self.width)
        self.scale_to_height: int = kwargs.get('scale_to_height', self.height)
        self.scale_to_width: int = kwargs.get('scale_to_width', int(self.width * self.dataset_config.scale))
        self.scale_to_height: int = kwargs.get('scale_to_height', int(self.height * self.dataset_config.scale))
        # crop values are from scaled size
        self.crop_x: int = kwargs.get('crop_x', 0)
        self.crop_y: int = kwargs.get('crop_y', 0)
        self.crop_width: int = kwargs.get('crop_width', self.scale_to_width)
        self.crop_height: int = kwargs.get('crop_height', self.scale_to_height)

        # process config
        self.dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
        self.network_weight: float = self.dataset_config.network_weight
        self.is_reg = self.dataset_config.is_reg
        self.tensor: Union[torch.Tensor, None] = None

        self.network_network_weight: float = self.dataset_config.network_weight
    def cleanup(self):
        self.tensor = None
class DataLoaderBatchDTO:
    def __init__(self, **kwargs):
        self.file_item: 'FileItemDTO' = kwargs.get('file_item', None)
        self.dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
        self.file_items: List['FileItemDTO'] = kwargs.get('file_items', None)
        self.tensor: torch.Tensor = torch.cat([x.tensor.unsqueeze(0) for x in self.file_items])

    def get_is_reg_list(self):
        return [x.is_reg for x in self.file_items]

    def get_network_weight_list(self):
        return [x.network_weight for x in self.file_items]

    def get_caption_list(
            self,
            trigger=None,
            to_replace_list=None,
            add_if_not_present=True
    ):
        return [x.get_caption(
            trigger=trigger,
            to_replace_list=to_replace_list,
            add_if_not_present=add_if_not_present
        ) for x in self.file_items]

    def cleanup(self):
        self.tensor = None
        for file_item in self.file_items:
            file_item.cleanup()
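A small worked example (illustrative sizes, with plain tensors standing in for processed FileItemDTO.tensor values) of the stacking arithmetic DataLoaderBatchDTO.__init__ performs: each item tensor is [3, H, W] after the ToTensor/Normalize transform, and unsqueeze(0) plus torch.cat gives a [batch, 3, H, W] batch tensor.

import torch

item_tensors = [torch.zeros(3, 512, 512) for _ in range(4)]       # four processed 512x512 items
batch_tensor = torch.cat([x.unsqueeze(0) for x in item_tensors])  # same cat/unsqueeze as the DTO
print(batch_tensor.shape)  # torch.Size([4, 3, 512, 512])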
@@ -1,8 +1,11 @@
import os
import random
from typing import TYPE_CHECKING, List, Dict
from typing import TYPE_CHECKING, List, Dict, Union

from toolkit.prompt_utils import inject_trigger_into_prompt
from torchvision import transforms
from PIL import Image
from PIL.ImageOps import exif_transpose

if TYPE_CHECKING:
    from toolkit.data_loader import AiToolkitDataset
@@ -159,6 +162,38 @@ class BucketsMixin:


class CaptionProcessingDTOMixin:

    # todo allow for loading from sd-scripts style dict
    def load_caption(self: 'FileItemDTO', caption_dict: Union[dict, None]):
        if self.raw_caption is not None:
            # we already loaded it
            pass
        elif caption_dict is not None and self.path in caption_dict and "caption" in caption_dict[self.path]:
            self.raw_caption = caption_dict[self.path]["caption"]
        else:
            # see if prompt file exists
            path_no_ext = os.path.splitext(self.path)[0]
            prompt_ext = self.dataset_config.caption_ext
            prompt_path = f"{path_no_ext}.{prompt_ext}"

            if os.path.exists(prompt_path):
                with open(prompt_path, 'r', encoding='utf-8') as f:
                    prompt = f.read()
                    # remove any newlines
                    prompt = prompt.replace('\n', ', ')
                    # remove new lines for all operating systems
                    prompt = prompt.replace('\r', ', ')
                    prompt_split = prompt.split(',')
                    # remove empty strings
                    prompt_split = [p.strip() for p in prompt_split if p.strip()]
                    # join back together
                    prompt = ', '.join(prompt_split)
            else:
                prompt = ''
                if self.dataset_config.default_caption is not None:
                    prompt = self.dataset_config.default_caption
            self.raw_caption = prompt

    def get_caption(
            self: 'FileItemDTO',
            trigger=None,
@@ -201,3 +236,51 @@ class CaptionProcessingDTOMixin:
            caption = ', '.join(token_list)
        caption = inject_trigger_into_prompt(caption, trigger, to_replace_list, add_if_not_present)
        return caption


class ImageProcessingDTOMixin:
    def load_and_process_image(
            self: 'FileItemDTO',
            transform: Union[None, transforms.Compose]
    ):
        # todo make sure this matches
        img = exif_transpose(Image.open(self.path)).convert('RGB')
        w, h = img.size
        if w > h and self.scale_to_width < self.scale_to_height:
            # throw error, they should match
            raise ValueError(
                f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
        elif h > w and self.scale_to_height < self.scale_to_width:
            # throw error, they should match
            raise ValueError(
                f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")

        if self.dataset_config.buckets:
            # todo allow scaling and cropping, will be hard to add
            # scale and crop based on file item
            img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC)
            img = transforms.CenterCrop((self.crop_height, self.crop_width))(img)
        else:
            # Downscale the source image first
            img = img.resize(
                (int(img.size[0] * self.dataset_config.scale), int(img.size[1] * self.dataset_config.scale)),
                Image.BICUBIC)
            min_img_size = min(img.size)
            if self.dataset_config.random_crop:
                if self.dataset_config.random_scale and min_img_size > self.dataset_config.resolution:
                    if min_img_size < self.dataset_config.resolution:
                        print(
                            f"Unexpected values: min_img_size={min_img_size}, self.resolution={self.dataset_config.resolution}, image file={self.path}")
                        scale_size = self.dataset_config.resolution
                    else:
                        scale_size = random.randint(self.dataset_config.resolution, int(min_img_size))
                    img = img.resize((scale_size, scale_size), Image.BICUBIC)
                img = transforms.RandomCrop(self.dataset_config.resolution)(img)
            else:
                img = transforms.CenterCrop(min_img_size)(img)
                img = img.resize((self.dataset_config.resolution, self.dataset_config.resolution), Image.BICUBIC)

        if transform:
            img = transform(img)

        self.tensor = img
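A minimal end-to-end sketch of the per-file DTO pipeline added above; the image path, caption file, and trigger word are hypothetical. The FileItemDTO reads its dimensions on construction, load_and_process_image scales/crops the image into .tensor, and load_caption pulls the sidecar .txt (falling back to the dataset's default_caption) before get_caption injects the trigger word.

from torchvision import transforms
from toolkit.config_modules import DatasetConfig
from toolkit.data_transfer_object.data_loader import FileItemDTO

config = DatasetConfig(folder_path='/data', caption_ext='txt',
                       default_caption='[trigger]', resolution=512)
item = FileItemDTO(path='/data/img_001.jpg', dataset_config=config)  # hypothetical file
item.load_and_process_image(transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
]))
item.load_caption(caption_dict=None)  # uses /data/img_001.txt if present, else default_caption
print(item.tensor.shape, item.get_caption(trigger='sks person'))
item.cleanup()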