mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-10 04:59:52 +00:00
Reworked the bucket loader to scale buckets to a target pixel count, not just a minimum size. Makes the network more consistent.
This commit is contained in:
@@ -60,12 +60,7 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
# activate network if it exists
|
||||
with network:
|
||||
with torch.set_grad_enabled(grad_on_text_encoder):
|
||||
embedding_list = []
|
||||
# embed the prompts
|
||||
for prompt in conditioned_prompts:
|
||||
embedding = self.sd.encode_prompt(prompt).to(self.device_torch, dtype=dtype)
|
||||
embedding_list.append(embedding)
|
||||
conditional_embeds = concat_prompt_embeds(embedding_list)
|
||||
conditional_embeds = self.sd.encode_prompt(conditioned_prompts).to(self.device_torch, dtype=dtype)
|
||||
if not grad_on_text_encoder:
|
||||
# detach the embeddings
|
||||
conditional_embeds = conditional_embeds.detach()
|
||||
|
||||
@@ -1,10 +1,21 @@
|
||||
from torch.utils.data import ConcatDataset, DataLoader
|
||||
from tqdm import tqdm
|
||||
# make sure we can import from the toolkit
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torchvision import transforms
|
||||
import sys
|
||||
import os
|
||||
import cv2
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
from toolkit.paths import SD_SCRIPTS_ROOT
|
||||
|
||||
from toolkit.image_utils import show_img
|
||||
|
||||
sys.path.append(SD_SCRIPTS_ROOT)
|
||||
|
||||
from library.model_util import load_vae
|
||||
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
|
||||
from toolkit.data_loader import AiToolkitDataset, get_dataloader_from_datasets
|
||||
from toolkit.config_modules import DatasetConfig
|
||||
import argparse
|
||||
@@ -12,6 +23,7 @@ import argparse
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('dataset_folder', type=str, default='input')
|
||||
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
dataset_folder = args.dataset_folder
|
||||
@@ -30,8 +42,29 @@ dataset_config = DatasetConfig(
|
||||
|
||||
dataloader = get_dataloader_from_datasets([dataset_config], batch_size=batch_size)
|
||||
|
||||
|
||||
# run through an epoch and check sizes
|
||||
for batch in dataloader:
|
||||
print(list(batch[0].shape))
|
||||
batch: 'DataLoaderBatchDTO'
|
||||
img_batch = batch.tensor
|
||||
|
||||
chunks = torch.chunk(img_batch, batch_size, dim=0)
|
||||
# put them so they are side by side
|
||||
big_img = torch.cat(chunks, dim=3)
|
||||
big_img = big_img.squeeze(0)
|
||||
|
||||
min_val = big_img.min()
|
||||
max_val = big_img.max()
|
||||
|
||||
big_img = (big_img / 2 + 0.5).clamp(0, 1)
|
||||
|
||||
# convert to image
|
||||
img = transforms.ToPILImage()(big_img)
|
||||
|
||||
show_img(img)
|
||||
|
||||
time.sleep(1.0)
|
||||
|
||||
cv2.destroyAllWindows()
|
||||
|
||||
print('done')
|
||||
|
||||
@@ -326,14 +326,19 @@ class AiToolkitDataset(Dataset, CaptionMixin, BucketsMixin):
|
||||
print(f" - Preprocessing image dimensions")
|
||||
bad_count = 0
|
||||
for file in tqdm(file_list):
|
||||
file_item = FileItemDTO(
|
||||
path=file,
|
||||
dataset_config=dataset_config
|
||||
)
|
||||
if file_item.scale_to_width < self.resolution or file_item.scale_to_height < self.resolution:
|
||||
try:
|
||||
file_item = FileItemDTO(
|
||||
path=file,
|
||||
dataset_config=dataset_config
|
||||
)
|
||||
if file_item.scale_to_width < self.resolution or file_item.scale_to_height < self.resolution:
|
||||
bad_count += 1
|
||||
else:
|
||||
self.file_list.append(file_item)
|
||||
except Exception as e:
|
||||
print(f"Error processing image: {file}")
|
||||
print(e)
|
||||
bad_count += 1
|
||||
else:
|
||||
self.file_list.append(file_item)
|
||||
|
||||
print(f" - Found {len(self.file_list)} images")
|
||||
print(f" - Found {bad_count} images that are too small")
|
||||
@@ -376,7 +381,7 @@ class AiToolkitDataset(Dataset, CaptionMixin, BucketsMixin):
|
||||
return self._get_single_item(item)
|
||||
|
||||
|
||||
def get_dataloader_from_datasets(dataset_options, batch_size=1):
|
||||
def get_dataloader_from_datasets(dataset_options, batch_size=1) -> DataLoader:
|
||||
if dataset_options is None or len(dataset_options) == 0:
|
||||
return None
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
from typing import TYPE_CHECKING, List, Dict, Union
|
||||
@@ -94,56 +95,56 @@ class BucketsMixin:
|
||||
bucket_tolerance = config.bucket_tolerance
|
||||
file_list: List['FileItemDTO'] = self.file_list
|
||||
|
||||
# make sure our resolution is divisible by bucket_tolerance
|
||||
if resolution % bucket_tolerance != 0:
|
||||
# reduce it to the nearest divisible number
|
||||
resolution = resolution - (resolution % bucket_tolerance)
|
||||
total_pixels = resolution * resolution
|
||||
|
||||
# for file_item in enumerate(file_list):
|
||||
for idx, file_item in enumerate(file_list):
|
||||
width = file_item.crop_width
|
||||
height = file_item.crop_height
|
||||
|
||||
# determine new size, smallest dimension should be equal to resolution
|
||||
# the other dimension should be the same ratio it is now (bigger)
|
||||
new_width = resolution
|
||||
new_height = resolution
|
||||
if width > height:
|
||||
# scale width to match new resolution,
|
||||
new_width = int(width * (resolution / height))
|
||||
file_item.crop_width = new_width
|
||||
file_item.scale_to_width = new_width
|
||||
file_item.crop_height = resolution
|
||||
file_item.scale_to_height = resolution
|
||||
# make sure new_width is divisible by bucket_tolerance
|
||||
# determine new resolution to have the same number of pixels
|
||||
current_pixels = width * height
|
||||
if current_pixels == total_pixels:
|
||||
# no change
|
||||
continue
|
||||
|
||||
aspect_ratio = width / height
|
||||
new_height = int(math.sqrt(total_pixels / aspect_ratio))
|
||||
new_width = int(aspect_ratio * new_height)
|
||||
|
||||
# increase smallest one to be divisible by bucket_tolerance and increase the other to match
|
||||
if new_width < new_height:
|
||||
# increase width
|
||||
if new_width % bucket_tolerance != 0:
|
||||
# reduce it to the nearest divisible number
|
||||
reduction = new_width % bucket_tolerance
|
||||
file_item.crop_width = new_width - reduction
|
||||
new_width = file_item.crop_width
|
||||
# adjust the new x position so we evenly crop
|
||||
file_item.crop_x = int(file_item.crop_x + (reduction / 2))
|
||||
elif height > width:
|
||||
# scale height to match new resolution
|
||||
new_height = int(height * (resolution / width))
|
||||
file_item.crop_height = new_height
|
||||
file_item.scale_to_height = new_height
|
||||
file_item.scale_to_width = resolution
|
||||
file_item.crop_width = resolution
|
||||
# make sure new_height is divisible by bucket_tolerance
|
||||
if new_height % bucket_tolerance != 0:
|
||||
# reduce it to the nearest divisible number
|
||||
reduction = new_height % bucket_tolerance
|
||||
file_item.crop_height = new_height - reduction
|
||||
new_height = file_item.crop_height
|
||||
# adjust the new x position so we evenly crop
|
||||
file_item.crop_y = int(file_item.crop_y + (reduction / 2))
|
||||
crop_amount = new_width % bucket_tolerance
|
||||
new_width = new_width + (bucket_tolerance - crop_amount)
|
||||
new_height = int(new_width / aspect_ratio)
|
||||
else:
|
||||
# square image
|
||||
file_item.crop_height = resolution
|
||||
file_item.scale_to_height = resolution
|
||||
file_item.scale_to_width = resolution
|
||||
file_item.crop_width = resolution
|
||||
# increase height
|
||||
if new_height % bucket_tolerance != 0:
|
||||
crop_amount = new_height % bucket_tolerance
|
||||
new_height = new_height + (bucket_tolerance - crop_amount)
|
||||
new_width = int(aspect_ratio * new_height)
|
||||
|
||||
# Ensure that the total number of pixels remains the same.
|
||||
# assert new_width * new_height == total_pixels
|
||||
|
||||
file_item.scale_to_width = new_width
|
||||
file_item.scale_to_height = new_height
|
||||
file_item.crop_width = new_width
|
||||
file_item.crop_height = new_height
|
||||
# make sure it is divisible by bucket_tolerance, decrease if not
|
||||
if new_width % bucket_tolerance != 0:
|
||||
crop_amount = new_width % bucket_tolerance
|
||||
file_item.crop_width = new_width - crop_amount
|
||||
else:
|
||||
file_item.crop_width = new_width
|
||||
|
||||
if new_height % bucket_tolerance != 0:
|
||||
crop_amount = new_height % bucket_tolerance
|
||||
file_item.crop_height = new_height - crop_amount
|
||||
else:
|
||||
file_item.crop_height = new_height
|
||||
|
||||
# check if bucket exists, if not, create it
|
||||
bucket_key = f'{new_width}x{new_height}'
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
# ref https://github.com/scardine/image_size/blob/master/get_image_size.py
|
||||
import atexit
|
||||
import collections
|
||||
import json
|
||||
import os
|
||||
import io
|
||||
import struct
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
FILE_UNKNOWN = "Sorry, don't know how to get size for this file."
|
||||
|
||||
|
||||
@@ -112,7 +116,7 @@ def get_image_metadata_from_bytesio(input, size, file_path=None):
|
||||
width = int(w)
|
||||
height = int(h)
|
||||
elif ((size >= 24) and data.startswith(b'\211PNG\r\n\032\n')
|
||||
and (data[12:16] == b'IHDR')):
|
||||
and (data[12:16] == b'IHDR')):
|
||||
# PNGs
|
||||
imgtype = PNG
|
||||
w, h = struct.unpack(">LL", data[16:24])
|
||||
@@ -190,7 +194,7 @@ def get_image_metadata_from_bytesio(input, size, file_path=None):
|
||||
9: (4, boChar + "l"), # SLONG
|
||||
10: (8, boChar + "ll"), # SRATIONAL
|
||||
11: (4, boChar + "f"), # FLOAT
|
||||
12: (8, boChar + "d") # DOUBLE
|
||||
12: (8, boChar + "d") # DOUBLE
|
||||
}
|
||||
ifdOffset = struct.unpack(boChar + "L", data[4:8])[0]
|
||||
try:
|
||||
@@ -206,7 +210,7 @@ def get_image_metadata_from_bytesio(input, size, file_path=None):
|
||||
input.seek(entryOffset)
|
||||
tag = input.read(2)
|
||||
tag = struct.unpack(boChar + "H", tag)[0]
|
||||
if(tag == 256 or tag == 257):
|
||||
if (tag == 256 or tag == 257):
|
||||
# if type indicates that value fits into 4 bytes, value
|
||||
# offset is not an offset but value itself
|
||||
type = input.read(2)
|
||||
@@ -229,7 +233,7 @@ def get_image_metadata_from_bytesio(input, size, file_path=None):
|
||||
except Exception as e:
|
||||
raise UnknownImageFormat(str(e))
|
||||
elif size >= 2:
|
||||
# see http://en.wikipedia.org/wiki/ICO_(file_format)
|
||||
# see http://en.wikipedia.org/wiki/ICO_(file_format)
|
||||
imgtype = 'ICO'
|
||||
input.seek(0)
|
||||
reserved = input.read(2)
|
||||
@@ -350,13 +354,13 @@ def main(argv=None):
|
||||
|
||||
prs.add_option('-v', '--verbose',
|
||||
dest='verbose',
|
||||
action='store_true',)
|
||||
action='store_true', )
|
||||
prs.add_option('-q', '--quiet',
|
||||
dest='quiet',
|
||||
action='store_true',)
|
||||
action='store_true', )
|
||||
prs.add_option('-t', '--test',
|
||||
dest='run_tests',
|
||||
action='store_true',)
|
||||
action='store_true', )
|
||||
|
||||
argv = list(argv) if argv is not None else sys.argv[1:]
|
||||
(opts, args) = prs.parse_args(args=argv)
|
||||
@@ -417,6 +421,36 @@ def main(argv=None):
|
||||
return EX_OK
|
||||
|
||||
|
||||
is_window_shown = False
|
||||
|
||||
|
||||
def show_img(img):
    """Display ``img`` in the 'AI Toolkit' OpenCV window.

    The input is clipped to [0, 255], cast to uint8, and shown with its
    last axis reversed. On the first call the frame is drawn twice so the
    window gets initialized; later calls draw it once.

    Raises:
        KeyboardInterrupt: if Esc (key code 27) is read while polling keys.
    """
    global is_window_shown

    frame = np.clip(img, 0, 255).astype(np.uint8)
    # reverse the last axis for display
    # NOTE(review): presumably RGB -> BGR for OpenCV; assumes an HxWxC input
    flipped = frame[:, :, ::-1]

    # one draw normally; a second draw on the first call to initialize the window
    draw_count = 1 if is_window_shown else 2
    for _ in range(draw_count):
        cv2.imshow('AI Toolkit', flipped)
        key = cv2.waitKey(10) & 0xFF
        if key == 27:  # Esc key to stop
            print('\nESC pressed, stopping')
            raise KeyboardInterrupt
    is_window_shown = True
|
||||
|
||||
|
||||
def on_exit():
    """Destroy all OpenCV windows, but only if ``is_window_shown`` is set."""
    if not is_window_shown:
        # no window was ever shown, nothing to clean up
        return
    cv2.destroyAllWindows()
|
||||
|
||||
|
||||
atexit.register(on_exit)
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
sys.exit(main(argv=sys.argv[1:]))
|
||||
|
||||
sys.exit(main(argv=sys.argv[1:]))
|
||||
|
||||
@@ -711,7 +711,7 @@ class StableDiffusion:
|
||||
|
||||
raise ValueError(f"Unknown weight name: {name}")
|
||||
|
||||
def inject_trigger_into_prompt(self, prompt, trigger=None, to_replace_list=None, add_if_not_present=True):
|
||||
def inject_trigger_into_prompt(self, prompt, trigger=None, to_replace_list=None, add_if_not_present=False):
|
||||
return inject_trigger_into_prompt(
|
||||
prompt,
|
||||
trigger=trigger,
|
||||
|
||||
Reference in New Issue
Block a user