Reworked the bucket loader to scale buckets to pixel amounts, not just minimum size. Makes the network more consistent

This commit is contained in:
Jaret Burkett
2023-08-30 14:52:12 -06:00
parent d401348c2e
commit 33267e117c
6 changed files with 137 additions and 69 deletions

View File

@@ -60,12 +60,7 @@ class SDTrainer(BaseSDTrainProcess):
# activate network if it exits
with network:
with torch.set_grad_enabled(grad_on_text_encoder):
embedding_list = []
# embed the prompts
for prompt in conditioned_prompts:
embedding = self.sd.encode_prompt(prompt).to(self.device_torch, dtype=dtype)
embedding_list.append(embedding)
conditional_embeds = concat_prompt_embeds(embedding_list)
conditional_embeds = self.sd.encode_prompt(conditioned_prompts).to(self.device_torch, dtype=dtype)
if not grad_on_text_encoder:
# detach the embeddings
conditional_embeds = conditional_embeds.detach()

View File

@@ -1,10 +1,21 @@
from torch.utils.data import ConcatDataset, DataLoader
from tqdm import tqdm
# make sure we can import from the toolkit
import time
import numpy as np
import torch
from torchvision import transforms
import sys
import os
import cv2
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from toolkit.paths import SD_SCRIPTS_ROOT
from toolkit.image_utils import show_img
sys.path.append(SD_SCRIPTS_ROOT)
from library.model_util import load_vae
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
from toolkit.data_loader import AiToolkitDataset, get_dataloader_from_datasets
from toolkit.config_modules import DatasetConfig
import argparse
@@ -12,6 +23,7 @@ import argparse
parser = argparse.ArgumentParser()
parser.add_argument('dataset_folder', type=str, default='input')
args = parser.parse_args()
dataset_folder = args.dataset_folder
@@ -30,8 +42,29 @@ dataset_config = DatasetConfig(
dataloader = get_dataloader_from_datasets([dataset_config], batch_size=batch_size)
# run through an epoch and check sizes
for batch in dataloader:
print(list(batch[0].shape))
batch: 'DataLoaderBatchDTO'
img_batch = batch.tensor
chunks = torch.chunk(img_batch, batch_size, dim=0)
# put them side by side
big_img = torch.cat(chunks, dim=3)
big_img = big_img.squeeze(0)
min_val = big_img.min()
max_val = big_img.max()
big_img = (big_img / 2 + 0.5).clamp(0, 1)
# convert to image
img = transforms.ToPILImage()(big_img)
show_img(img)
time.sleep(1.0)
cv2.destroyAllWindows()
print('done')

View File

@@ -326,14 +326,19 @@ class AiToolkitDataset(Dataset, CaptionMixin, BucketsMixin):
print(f" - Preprocessing image dimensions")
bad_count = 0
for file in tqdm(file_list):
file_item = FileItemDTO(
path=file,
dataset_config=dataset_config
)
if file_item.scale_to_width < self.resolution or file_item.scale_to_height < self.resolution:
try:
file_item = FileItemDTO(
path=file,
dataset_config=dataset_config
)
if file_item.scale_to_width < self.resolution or file_item.scale_to_height < self.resolution:
bad_count += 1
else:
self.file_list.append(file_item)
except Exception as e:
print(f"Error processing image: {file}")
print(e)
bad_count += 1
else:
self.file_list.append(file_item)
print(f" - Found {len(self.file_list)} images")
print(f" - Found {bad_count} images that are too small")
@@ -376,7 +381,7 @@ class AiToolkitDataset(Dataset, CaptionMixin, BucketsMixin):
return self._get_single_item(item)
def get_dataloader_from_datasets(dataset_options, batch_size=1):
def get_dataloader_from_datasets(dataset_options, batch_size=1) -> DataLoader:
if dataset_options is None or len(dataset_options) == 0:
return None

View File

@@ -1,3 +1,4 @@
import math
import os
import random
from typing import TYPE_CHECKING, List, Dict, Union
@@ -94,56 +95,56 @@ class BucketsMixin:
bucket_tolerance = config.bucket_tolerance
file_list: List['FileItemDTO'] = self.file_list
# make sure our resolution is divisible by bucket_tolerance
if resolution % bucket_tolerance != 0:
# reduce it to the nearest divisible number
resolution = resolution - (resolution % bucket_tolerance)
total_pixels = resolution * resolution
# for file_item in enumerate(file_list):
for idx, file_item in enumerate(file_list):
width = file_item.crop_width
height = file_item.crop_height
# determine new size, smallest dimension should be equal to resolution
# the other dimension should be the same ratio it is now (bigger)
new_width = resolution
new_height = resolution
if width > height:
# scale width to match new resolution,
new_width = int(width * (resolution / height))
file_item.crop_width = new_width
file_item.scale_to_width = new_width
file_item.crop_height = resolution
file_item.scale_to_height = resolution
# make sure new_width is divisible by bucket_tolerance
# determine new resolution to have the same number of pixels
current_pixels = width * height
if current_pixels == total_pixels:
# no change
continue
aspect_ratio = width / height
new_height = int(math.sqrt(total_pixels / aspect_ratio))
new_width = int(aspect_ratio * new_height)
# increase smallest one to be divisible by bucket_tolerance and increase the other to match
if new_width < new_height:
# increase width
if new_width % bucket_tolerance != 0:
# reduce it to the nearest divisible number
reduction = new_width % bucket_tolerance
file_item.crop_width = new_width - reduction
new_width = file_item.crop_width
# adjust the new x position so we evenly crop
file_item.crop_x = int(file_item.crop_x + (reduction / 2))
elif height > width:
# scale height to match new resolution
new_height = int(height * (resolution / width))
file_item.crop_height = new_height
file_item.scale_to_height = new_height
file_item.scale_to_width = resolution
file_item.crop_width = resolution
# make sure new_height is divisible by bucket_tolerance
if new_height % bucket_tolerance != 0:
# reduce it to the nearest divisible number
reduction = new_height % bucket_tolerance
file_item.crop_height = new_height - reduction
new_height = file_item.crop_height
# adjust the new y position so we evenly crop
file_item.crop_y = int(file_item.crop_y + (reduction / 2))
crop_amount = new_width % bucket_tolerance
new_width = new_width + (bucket_tolerance - crop_amount)
new_height = int(new_width / aspect_ratio)
else:
# square image
file_item.crop_height = resolution
file_item.scale_to_height = resolution
file_item.scale_to_width = resolution
file_item.crop_width = resolution
# increase height
if new_height % bucket_tolerance != 0:
crop_amount = new_height % bucket_tolerance
new_height = new_height + (bucket_tolerance - crop_amount)
new_width = int(aspect_ratio * new_height)
# Ensure that the total number of pixels remains the same.
# assert new_width * new_height == total_pixels
file_item.scale_to_width = new_width
file_item.scale_to_height = new_height
file_item.crop_width = new_width
file_item.crop_height = new_height
# make sure it is divisible by bucket_tolerance, decrease if not
if new_width % bucket_tolerance != 0:
crop_amount = new_width % bucket_tolerance
file_item.crop_width = new_width - crop_amount
else:
file_item.crop_width = new_width
if new_height % bucket_tolerance != 0:
crop_amount = new_height % bucket_tolerance
file_item.crop_height = new_height - crop_amount
else:
file_item.crop_height = new_height
# check if bucket exists, if not, create it
bucket_key = f'{new_width}x{new_height}'

View File

@@ -1,10 +1,14 @@
# ref https://github.com/scardine/image_size/blob/master/get_image_size.py
import atexit
import collections
import json
import os
import io
import struct
import cv2
import numpy as np
FILE_UNKNOWN = "Sorry, don't know how to get size for this file."
@@ -112,7 +116,7 @@ def get_image_metadata_from_bytesio(input, size, file_path=None):
width = int(w)
height = int(h)
elif ((size >= 24) and data.startswith(b'\211PNG\r\n\032\n')
and (data[12:16] == b'IHDR')):
and (data[12:16] == b'IHDR')):
# PNGs
imgtype = PNG
w, h = struct.unpack(">LL", data[16:24])
@@ -190,7 +194,7 @@ def get_image_metadata_from_bytesio(input, size, file_path=None):
9: (4, boChar + "l"), # SLONG
10: (8, boChar + "ll"), # SRATIONAL
11: (4, boChar + "f"), # FLOAT
12: (8, boChar + "d") # DOUBLE
12: (8, boChar + "d") # DOUBLE
}
ifdOffset = struct.unpack(boChar + "L", data[4:8])[0]
try:
@@ -206,7 +210,7 @@ def get_image_metadata_from_bytesio(input, size, file_path=None):
input.seek(entryOffset)
tag = input.read(2)
tag = struct.unpack(boChar + "H", tag)[0]
if(tag == 256 or tag == 257):
if (tag == 256 or tag == 257):
# if type indicates that value fits into 4 bytes, value
# offset is not an offset but value itself
type = input.read(2)
@@ -229,7 +233,7 @@ def get_image_metadata_from_bytesio(input, size, file_path=None):
except Exception as e:
raise UnknownImageFormat(str(e))
elif size >= 2:
# see http://en.wikipedia.org/wiki/ICO_(file_format)
# see http://en.wikipedia.org/wiki/ICO_(file_format)
imgtype = 'ICO'
input.seek(0)
reserved = input.read(2)
@@ -350,13 +354,13 @@ def main(argv=None):
prs.add_option('-v', '--verbose',
dest='verbose',
action='store_true',)
action='store_true', )
prs.add_option('-q', '--quiet',
dest='quiet',
action='store_true',)
action='store_true', )
prs.add_option('-t', '--test',
dest='run_tests',
action='store_true',)
action='store_true', )
argv = list(argv) if argv is not None else sys.argv[1:]
(opts, args) = prs.parse_args(args=argv)
@@ -417,6 +421,36 @@ def main(argv=None):
return EX_OK
is_window_shown = False
def show_img(img):
    """Display *img* in the shared 'AI Toolkit' OpenCV window.

    Clips values to [0, 255], converts to uint8, and flips the channel
    axis before display (cv2 expects BGR; input is presumably RGB/HWC —
    TODO confirm with callers). Raises KeyboardInterrupt if ESC is
    pressed while the window is polled.
    """
    global is_window_shown
    img = np.clip(img, 0, 255).astype(np.uint8)
    bgr = img[:, :, ::-1]
    # On the very first call the window needs a second imshow/waitKey
    # cycle to actually initialize, so render twice that one time.
    repeats = 1 if is_window_shown else 2
    for _ in range(repeats):
        cv2.imshow('AI Toolkit', bgr)
        k = cv2.waitKey(10) & 0xFF
        if k == 27:  # Esc key to stop
            print('\nESC pressed, stopping')
            raise KeyboardInterrupt
    is_window_shown = True
def on_exit():
    """Atexit hook: close any OpenCV windows opened by show_img()."""
    if not is_window_shown:
        return
    cv2.destroyAllWindows()


atexit.register(on_exit)
if __name__ == "__main__":
import sys
sys.exit(main(argv=sys.argv[1:]))
sys.exit(main(argv=sys.argv[1:]))

View File

@@ -711,7 +711,7 @@ class StableDiffusion:
raise ValueError(f"Unknown weight name: {name}")
def inject_trigger_into_prompt(self, prompt, trigger=None, to_replace_list=None, add_if_not_present=True):
def inject_trigger_into_prompt(self, prompt, trigger=None, to_replace_list=None, add_if_not_present=False):
return inject_trigger_into_prompt(
prompt,
trigger=trigger,