allow flipping for point of interesting autocropping. allow num repeats. Fixed some bugs with new free u

This commit is contained in:
Jaret Burkett
2023-10-12 21:02:47 -06:00
parent 4e3b2c2569
commit 38e441a29c
5 changed files with 27 additions and 4 deletions

View File

@@ -928,6 +928,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
flush()
# self.step_num = 0
for step in range(self.step_num, self.train_config.steps):
if self.train_config.free_u:
self.sd.pipeline.enable_freeu(s1=0.9, s2=0.2, b1=1.1, b2=1.2)
self.progress_bar.unpause()
with torch.no_grad():
# if is even step and we have a reg dataset, use that
@@ -1000,6 +1002,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
if is_sample_step:
self.progress_bar.pause()
# print above the progress bar
if self.train_config.free_u:
self.sd.pipeline.disable_freeu()
self.sample(self.step_num)
self.progress_bar.unpause()
@@ -1047,6 +1051,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
# flush()
self.progress_bar.close()
if self.train_config.free_u:
self.sd.pipeline.disable_freeu()
self.sample(self.step_num + 1)
print("")
self.save()

View File

@@ -120,6 +120,7 @@ class TrainConfig:
self.merge_network_on_save = kwargs.get('merge_network_on_save', False)
self.max_grad_norm = kwargs.get('max_grad_norm', 1.0)
self.start_step = kwargs.get('start_step', None)
self.free_u = kwargs.get('free_u', False)
class ModelConfig:
@@ -239,7 +240,7 @@ class DatasetConfig:
self.mask_path: str = kwargs.get('mask_path', None) # focus mask (black and white. White has higher loss than black)
self.mask_min_value: float = kwargs.get('mask_min_value', 0.01) # min value for . 0 - 1
self.poi: Union[str, None] = kwargs.get('poi', None) # if one is set and in json data, will be used as auto crop scale point of interes
self.num_repeats: int = kwargs.get('num_repeats', 1) # number of times to repeat dataset
# cache latents will store them in memory
self.cache_latents: bool = kwargs.get('cache_latents', False)
# cache latents to disk will store them on disk. If both are true, it will save to disk, but keep in memory

View File

@@ -377,6 +377,10 @@ class AiToolkitDataset(LatentCachingMixin, BucketsMixin, CaptionMixin, Dataset):
# keys are file paths
file_list = list(self.caption_dict.keys())
if self.dataset_config.num_repeats > 1:
# repeat the list
file_list = file_list * self.dataset_config.num_repeats
# this might take a while
print(f" - Preprocessing image dimensions")
bad_count = 0

View File

@@ -156,8 +156,9 @@ class BucketsMixin:
# Use the maximum of the scale factors to ensure both dimensions are scaled above the bucket resolution
max_scale_factor = max(width_scale_factor, height_scale_factor)
file_item.scale_to_width = int(width * max_scale_factor)
file_item.scale_to_height = int(height * max_scale_factor)
# round up
file_item.scale_to_width = int(math.ceil(width * max_scale_factor))
file_item.scale_to_height = int(math.ceil(height * max_scale_factor))
file_item.crop_height = bucket_resolution["height"]
file_item.crop_width = bucket_resolution["width"]
@@ -294,7 +295,7 @@ class ImageProcessingDTOMixin:
self.load_mask_image()
return
try:
img = Image.open(self.mask_path)
img = Image.open(self.path)
img = exif_transpose(img)
except Exception as e:
print(f"Error: {e}")
@@ -583,6 +584,14 @@ class PoiFileItemDTOMixin:
self.poi_width = int(poi['width'])
self.poi_height = int(poi['height'])
# handle flipping
if kwargs.get('flip_x', False):
# flip the poi
self.poi_x = self.width - self.poi_x - self.poi_width
if kwargs.get('flip_y', False):
# flip the poi
self.poi_y = self.height - self.poi_y - self.poi_height
def setup_poi_bucket(self: 'FileItemDTO'):
# we are using poi, so we need to calculate the bucket based on the poi

View File

@@ -5,6 +5,9 @@ import os
import time
from typing import TYPE_CHECKING, Union
import sys
from torch.cuda.amp import GradScaler
from toolkit.paths import SD_SCRIPTS_ROOT
sys.path.append(SD_SCRIPTS_ROOT)