mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
allow flipping for point of interesting autocropping. allow num repeats. Fixed some bugs with new free u
This commit is contained in:
@@ -928,6 +928,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
flush()
|
||||
# self.step_num = 0
|
||||
for step in range(self.step_num, self.train_config.steps):
|
||||
if self.train_config.free_u:
|
||||
self.sd.pipeline.enable_freeu(s1=0.9, s2=0.2, b1=1.1, b2=1.2)
|
||||
self.progress_bar.unpause()
|
||||
with torch.no_grad():
|
||||
# if is even step and we have a reg dataset, use that
|
||||
@@ -1000,6 +1002,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
if is_sample_step:
|
||||
self.progress_bar.pause()
|
||||
# print above the progress bar
|
||||
if self.train_config.free_u:
|
||||
self.sd.pipeline.disable_freeu()
|
||||
self.sample(self.step_num)
|
||||
self.progress_bar.unpause()
|
||||
|
||||
@@ -1047,6 +1051,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
# flush()
|
||||
|
||||
self.progress_bar.close()
|
||||
if self.train_config.free_u:
|
||||
self.sd.pipeline.disable_freeu()
|
||||
self.sample(self.step_num + 1)
|
||||
print("")
|
||||
self.save()
|
||||
|
||||
@@ -120,6 +120,7 @@ class TrainConfig:
|
||||
self.merge_network_on_save = kwargs.get('merge_network_on_save', False)
|
||||
self.max_grad_norm = kwargs.get('max_grad_norm', 1.0)
|
||||
self.start_step = kwargs.get('start_step', None)
|
||||
self.free_u = kwargs.get('free_u', False)
|
||||
|
||||
|
||||
class ModelConfig:
|
||||
@@ -239,7 +240,7 @@ class DatasetConfig:
|
||||
self.mask_path: str = kwargs.get('mask_path', None) # focus mask (black and white. White has higher loss than black)
|
||||
self.mask_min_value: float = kwargs.get('mask_min_value', 0.01) # min value for . 0 - 1
|
||||
self.poi: Union[str, None] = kwargs.get('poi', None) # if one is set and in json data, will be used as auto crop scale point of interes
|
||||
|
||||
self.num_repeats: int = kwargs.get('num_repeats', 1) # number of times to repeat dataset
|
||||
# cache latents will store them in memory
|
||||
self.cache_latents: bool = kwargs.get('cache_latents', False)
|
||||
# cache latents to disk will store them on disk. If both are true, it will save to disk, but keep in memory
|
||||
|
||||
@@ -377,6 +377,10 @@ class AiToolkitDataset(LatentCachingMixin, BucketsMixin, CaptionMixin, Dataset):
|
||||
# keys are file paths
|
||||
file_list = list(self.caption_dict.keys())
|
||||
|
||||
if self.dataset_config.num_repeats > 1:
|
||||
# repeat the list
|
||||
file_list = file_list * self.dataset_config.num_repeats
|
||||
|
||||
# this might take a while
|
||||
print(f" - Preprocessing image dimensions")
|
||||
bad_count = 0
|
||||
|
||||
@@ -156,8 +156,9 @@ class BucketsMixin:
|
||||
# Use the maximum of the scale factors to ensure both dimensions are scaled above the bucket resolution
|
||||
max_scale_factor = max(width_scale_factor, height_scale_factor)
|
||||
|
||||
file_item.scale_to_width = int(width * max_scale_factor)
|
||||
file_item.scale_to_height = int(height * max_scale_factor)
|
||||
# round up
|
||||
file_item.scale_to_width = int(math.ceil(width * max_scale_factor))
|
||||
file_item.scale_to_height = int(math.ceil(height * max_scale_factor))
|
||||
|
||||
file_item.crop_height = bucket_resolution["height"]
|
||||
file_item.crop_width = bucket_resolution["width"]
|
||||
@@ -294,7 +295,7 @@ class ImageProcessingDTOMixin:
|
||||
self.load_mask_image()
|
||||
return
|
||||
try:
|
||||
img = Image.open(self.mask_path)
|
||||
img = Image.open(self.path)
|
||||
img = exif_transpose(img)
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
@@ -583,6 +584,14 @@ class PoiFileItemDTOMixin:
|
||||
self.poi_width = int(poi['width'])
|
||||
self.poi_height = int(poi['height'])
|
||||
|
||||
# handle flipping
|
||||
if kwargs.get('flip_x', False):
|
||||
# flip the poi
|
||||
self.poi_x = self.width - self.poi_x - self.poi_width
|
||||
if kwargs.get('flip_y', False):
|
||||
# flip the poi
|
||||
self.poi_y = self.height - self.poi_y - self.poi_height
|
||||
|
||||
def setup_poi_bucket(self: 'FileItemDTO'):
|
||||
# we are using poi, so we need to calculate the bucket based on the poi
|
||||
|
||||
|
||||
@@ -5,6 +5,9 @@ import os
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Union
|
||||
import sys
|
||||
|
||||
from torch.cuda.amp import GradScaler
|
||||
|
||||
from toolkit.paths import SD_SCRIPTS_ROOT
|
||||
|
||||
sys.path.append(SD_SCRIPTS_ROOT)
|
||||
|
||||
Reference in New Issue
Block a user