Allow flipping for point of interest autocropping. Allow num_repeats. Fixed some bugs with new FreeU.

Jaret Burkett
2023-10-12 21:02:47 -06:00
parent 4e3b2c2569
commit 38e441a29c
5 changed files with 27 additions and 4 deletions

View File

@@ -928,6 +928,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
        flush()
        # self.step_num = 0
        for step in range(self.step_num, self.train_config.steps):
+           if self.train_config.free_u:
+               self.sd.pipeline.enable_freeu(s1=0.9, s2=0.2, b1=1.1, b2=1.2)
            self.progress_bar.unpause()
            with torch.no_grad():
                # if is even step and we have a reg dataset, use that
@@ -1000,6 +1002,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
                if is_sample_step:
                    self.progress_bar.pause()
                    # print above the progress bar
+                   if self.train_config.free_u:
+                       self.sd.pipeline.disable_freeu()
                    self.sample(self.step_num)
                    self.progress_bar.unpause()
@@ -1047,6 +1051,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
        # flush()
        self.progress_bar.close()
+       if self.train_config.free_u:
+           self.sd.pipeline.disable_freeu()
        self.sample(self.step_num + 1)
        print("")
        self.save()
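For reference, FreeU here is the diffusers pipeline feature toggled by enable_freeu()/disable_freeu(); the loop above switches it on for training steps and off again before sampling. A minimal standalone sketch of the same toggle on a plain diffusers pipeline (the model id and prompt are placeholders, not taken from this repo):

from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")  # placeholder model id
pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.1, b2=1.2)   # same scaling factors as the hunk above
image = pipe("a photo of a cat").images[0]          # generation now runs with FreeU applied
pipe.disable_freeu()                                # restore the unmodified UNet behaviour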

View File

@@ -120,6 +120,7 @@ class TrainConfig:
        self.merge_network_on_save = kwargs.get('merge_network_on_save', False)
        self.max_grad_norm = kwargs.get('max_grad_norm', 1.0)
        self.start_step = kwargs.get('start_step', None)
+       self.free_u = kwargs.get('free_u', False)

class ModelConfig:
@@ -239,7 +240,7 @@ class DatasetConfig:
        self.mask_path: str = kwargs.get('mask_path', None)  # focus mask (black and white. White has higher loss than black)
        self.mask_min_value: float = kwargs.get('mask_min_value', 0.01)  # min value for the mask. 0 - 1
        self.poi: Union[str, None] = kwargs.get('poi', None)  # if one is set and in json data, will be used as auto crop scale point of interest
+       self.num_repeats: int = kwargs.get('num_repeats', 1)  # number of times to repeat dataset
        # cache latents will store them in memory
        self.cache_latents: bool = kwargs.get('cache_latents', False)
        # cache latents to disk will store them on disk. If both are true, it will save to disk, but keep in memory
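Putting the two new config keys together, a config built with the kwargs pattern above might look like the following sketch (free_u, poi and num_repeats are the keys touched in this commit; all other values are purely illustrative):

train_config = TrainConfig(
    steps=2000,     # illustrative
    free_u=True,    # toggles pipeline.enable_freeu()/disable_freeu() around training and sampling
)
dataset_config = DatasetConfig(
    poi="person",   # point-of-interest key looked up in the per-image json data
    num_repeats=3,  # iterate the dataset three times per pass
)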

View File

@@ -377,6 +377,10 @@ class AiToolkitDataset(LatentCachingMixin, BucketsMixin, CaptionMixin, Dataset):
        # keys are file paths
        file_list = list(self.caption_dict.keys())
+       if self.dataset_config.num_repeats > 1:
+           # repeat the list
+           file_list = file_list * self.dataset_config.num_repeats
        # this might take a while
        print(f" - Preprocessing image dimensions")
        bad_count = 0
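Since the repeat is plain list multiplication, num_repeats=3 simply makes every file path appear three times per epoch, e.g.:

file_list = ["a.jpg", "b.jpg"]
file_list = file_list * 3
# -> ['a.jpg', 'b.jpg', 'a.jpg', 'b.jpg', 'a.jpg', 'b.jpg']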

View File

@@ -156,8 +156,9 @@ class BucketsMixin:
        # Use the maximum of the scale factors to ensure both dimensions are scaled above the bucket resolution
        max_scale_factor = max(width_scale_factor, height_scale_factor)
-       file_item.scale_to_width = int(width * max_scale_factor)
-       file_item.scale_to_height = int(height * max_scale_factor)
+       # round up
+       file_item.scale_to_width = int(math.ceil(width * max_scale_factor))
+       file_item.scale_to_height = int(math.ceil(height * max_scale_factor))

        file_item.crop_height = bucket_resolution["height"]
        file_item.crop_width = bucket_resolution["width"]
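The switch to math.ceil matters because int() truncates toward zero: with floating-point scale factors the constraining dimension can come out a hair under the bucket size and lose a pixel, which the later crop then cannot cover. A tiny illustration (the literal is an example of such a float artifact, not a value from the repo):

import math

bucket_width = 512
scaled = 511.99999999999994    # what width * max_scale_factor can land on in float math
print(int(scaled))             # 511 -> one pixel short of the 512 crop
print(int(math.ceil(scaled)))  # 512 -> always at least the bucket resolution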
@@ -294,7 +295,7 @@ class ImageProcessingDTOMixin:
            self.load_mask_image()
        return
        try:
-           img = Image.open(self.mask_path)
+           img = Image.open(self.path)
            img = exif_transpose(img)
        except Exception as e:
            print(f"Error: {e}")
@@ -583,6 +584,14 @@ class PoiFileItemDTOMixin:
            self.poi_width = int(poi['width'])
            self.poi_height = int(poi['height'])
+           # handle flipping
+           if kwargs.get('flip_x', False):
+               # flip the poi
+               self.poi_x = self.width - self.poi_x - self.poi_width
+           if kwargs.get('flip_y', False):
+               # flip the poi
+               self.poi_y = self.height - self.poi_y - self.poi_height

    def setup_poi_bucket(self: 'FileItemDTO'):
        # we are using poi, so we need to calculate the bucket based on the poi
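The flip math mirrors the POI box inside the image: the new left edge is the image width minus the old left edge minus the box width (and likewise for the vertical axis). A quick check with made-up numbers:

width = 1024                  # image width
poi_x, poi_width = 100, 200   # box starts 100 px from the left and is 200 px wide
flipped_x = width - poi_x - poi_width
print(flipped_x)              # 724 -> the mirrored box now ends 100 px from the right edge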

View File

@@ -5,6 +5,9 @@ import os
import time
from typing import TYPE_CHECKING, Union
import sys
+from torch.cuda.amp import GradScaler
from toolkit.paths import SD_SCRIPTS_ROOT
sys.path.append(SD_SCRIPTS_ROOT)
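This hunk only adds the GradScaler import; for context, the standard torch.cuda.amp pattern it belongs to looks roughly like this (a generic mixed-precision sketch with toy model and data, not necessarily how this trainer wires it up):

import torch
from torch.cuda.amp import GradScaler, autocast

model = torch.nn.Linear(8, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scaler = GradScaler()

for _ in range(10):
    x = torch.randn(4, 8, device="cuda")
    optimizer.zero_grad()
    with autocast():
        loss = model(x).mean()
    scaler.scale(loss).backward()  # scale the loss so fp16 gradients do not underflow
    scaler.step(optimizer)         # unscales gradients, then runs the optimizer step
    scaler.update()                # adjusts the scale factor for the next iteration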