mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-28 10:11:14 +00:00
allow flipping for point of interesting autocropping. allow num repeats. Fixed some bugs with new free u
This commit is contained in:
@@ -928,6 +928,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
flush()
|
flush()
|
||||||
# self.step_num = 0
|
# self.step_num = 0
|
||||||
for step in range(self.step_num, self.train_config.steps):
|
for step in range(self.step_num, self.train_config.steps):
|
||||||
|
if self.train_config.free_u:
|
||||||
|
self.sd.pipeline.enable_freeu(s1=0.9, s2=0.2, b1=1.1, b2=1.2)
|
||||||
self.progress_bar.unpause()
|
self.progress_bar.unpause()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
# if is even step and we have a reg dataset, use that
|
# if is even step and we have a reg dataset, use that
|
||||||
@@ -1000,6 +1002,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
if is_sample_step:
|
if is_sample_step:
|
||||||
self.progress_bar.pause()
|
self.progress_bar.pause()
|
||||||
# print above the progress bar
|
# print above the progress bar
|
||||||
|
if self.train_config.free_u:
|
||||||
|
self.sd.pipeline.disable_freeu()
|
||||||
self.sample(self.step_num)
|
self.sample(self.step_num)
|
||||||
self.progress_bar.unpause()
|
self.progress_bar.unpause()
|
||||||
|
|
||||||
@@ -1047,6 +1051,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
# flush()
|
# flush()
|
||||||
|
|
||||||
self.progress_bar.close()
|
self.progress_bar.close()
|
||||||
|
if self.train_config.free_u:
|
||||||
|
self.sd.pipeline.disable_freeu()
|
||||||
self.sample(self.step_num + 1)
|
self.sample(self.step_num + 1)
|
||||||
print("")
|
print("")
|
||||||
self.save()
|
self.save()
|
||||||
|
|||||||
@@ -120,6 +120,7 @@ class TrainConfig:
|
|||||||
self.merge_network_on_save = kwargs.get('merge_network_on_save', False)
|
self.merge_network_on_save = kwargs.get('merge_network_on_save', False)
|
||||||
self.max_grad_norm = kwargs.get('max_grad_norm', 1.0)
|
self.max_grad_norm = kwargs.get('max_grad_norm', 1.0)
|
||||||
self.start_step = kwargs.get('start_step', None)
|
self.start_step = kwargs.get('start_step', None)
|
||||||
|
self.free_u = kwargs.get('free_u', False)
|
||||||
|
|
||||||
|
|
||||||
class ModelConfig:
|
class ModelConfig:
|
||||||
@@ -239,7 +240,7 @@ class DatasetConfig:
|
|||||||
self.mask_path: str = kwargs.get('mask_path', None) # focus mask (black and white. White has higher loss than black)
|
self.mask_path: str = kwargs.get('mask_path', None) # focus mask (black and white. White has higher loss than black)
|
||||||
self.mask_min_value: float = kwargs.get('mask_min_value', 0.01) # min value for . 0 - 1
|
self.mask_min_value: float = kwargs.get('mask_min_value', 0.01) # min value for . 0 - 1
|
||||||
self.poi: Union[str, None] = kwargs.get('poi', None) # if one is set and in json data, will be used as auto crop scale point of interes
|
self.poi: Union[str, None] = kwargs.get('poi', None) # if one is set and in json data, will be used as auto crop scale point of interes
|
||||||
|
self.num_repeats: int = kwargs.get('num_repeats', 1) # number of times to repeat dataset
|
||||||
# cache latents will store them in memory
|
# cache latents will store them in memory
|
||||||
self.cache_latents: bool = kwargs.get('cache_latents', False)
|
self.cache_latents: bool = kwargs.get('cache_latents', False)
|
||||||
# cache latents to disk will store them on disk. If both are true, it will save to disk, but keep in memory
|
# cache latents to disk will store them on disk. If both are true, it will save to disk, but keep in memory
|
||||||
|
|||||||
@@ -377,6 +377,10 @@ class AiToolkitDataset(LatentCachingMixin, BucketsMixin, CaptionMixin, Dataset):
|
|||||||
# keys are file paths
|
# keys are file paths
|
||||||
file_list = list(self.caption_dict.keys())
|
file_list = list(self.caption_dict.keys())
|
||||||
|
|
||||||
|
if self.dataset_config.num_repeats > 1:
|
||||||
|
# repeat the list
|
||||||
|
file_list = file_list * self.dataset_config.num_repeats
|
||||||
|
|
||||||
# this might take a while
|
# this might take a while
|
||||||
print(f" - Preprocessing image dimensions")
|
print(f" - Preprocessing image dimensions")
|
||||||
bad_count = 0
|
bad_count = 0
|
||||||
|
|||||||
@@ -156,8 +156,9 @@ class BucketsMixin:
|
|||||||
# Use the maximum of the scale factors to ensure both dimensions are scaled above the bucket resolution
|
# Use the maximum of the scale factors to ensure both dimensions are scaled above the bucket resolution
|
||||||
max_scale_factor = max(width_scale_factor, height_scale_factor)
|
max_scale_factor = max(width_scale_factor, height_scale_factor)
|
||||||
|
|
||||||
file_item.scale_to_width = int(width * max_scale_factor)
|
# round up
|
||||||
file_item.scale_to_height = int(height * max_scale_factor)
|
file_item.scale_to_width = int(math.ceil(width * max_scale_factor))
|
||||||
|
file_item.scale_to_height = int(math.ceil(height * max_scale_factor))
|
||||||
|
|
||||||
file_item.crop_height = bucket_resolution["height"]
|
file_item.crop_height = bucket_resolution["height"]
|
||||||
file_item.crop_width = bucket_resolution["width"]
|
file_item.crop_width = bucket_resolution["width"]
|
||||||
@@ -294,7 +295,7 @@ class ImageProcessingDTOMixin:
|
|||||||
self.load_mask_image()
|
self.load_mask_image()
|
||||||
return
|
return
|
||||||
try:
|
try:
|
||||||
img = Image.open(self.mask_path)
|
img = Image.open(self.path)
|
||||||
img = exif_transpose(img)
|
img = exif_transpose(img)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error: {e}")
|
print(f"Error: {e}")
|
||||||
@@ -583,6 +584,14 @@ class PoiFileItemDTOMixin:
|
|||||||
self.poi_width = int(poi['width'])
|
self.poi_width = int(poi['width'])
|
||||||
self.poi_height = int(poi['height'])
|
self.poi_height = int(poi['height'])
|
||||||
|
|
||||||
|
# handle flipping
|
||||||
|
if kwargs.get('flip_x', False):
|
||||||
|
# flip the poi
|
||||||
|
self.poi_x = self.width - self.poi_x - self.poi_width
|
||||||
|
if kwargs.get('flip_y', False):
|
||||||
|
# flip the poi
|
||||||
|
self.poi_y = self.height - self.poi_y - self.poi_height
|
||||||
|
|
||||||
def setup_poi_bucket(self: 'FileItemDTO'):
|
def setup_poi_bucket(self: 'FileItemDTO'):
|
||||||
# we are using poi, so we need to calculate the bucket based on the poi
|
# we are using poi, so we need to calculate the bucket based on the poi
|
||||||
|
|
||||||
|
|||||||
@@ -5,6 +5,9 @@ import os
|
|||||||
import time
|
import time
|
||||||
from typing import TYPE_CHECKING, Union
|
from typing import TYPE_CHECKING, Union
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
from torch.cuda.amp import GradScaler
|
||||||
|
|
||||||
from toolkit.paths import SD_SCRIPTS_ROOT
|
from toolkit.paths import SD_SCRIPTS_ROOT
|
||||||
|
|
||||||
sys.path.append(SD_SCRIPTS_ROOT)
|
sys.path.append(SD_SCRIPTS_ROOT)
|
||||||
|
|||||||
Reference in New Issue
Block a user