mirror of https://github.com/ostris/ai-toolkit.git
Bug fixes, plus a new method to standardize values in both latent and pixel space before feeding them into the network. The target values were determined over large generated regularization sets.
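In plain terms, the new step re-centers each channel of an image or latent batch to zero mean and unit std, then shifts it onto fixed target statistics measured over those regularization sets. A minimal standalone sketch of the transform (the helper name and signature are illustrative, not code from the repo):

    import torch

    def standardize_channels(x: torch.Tensor, target_mean: list, target_std: list) -> torch.Tensor:
        # x has shape (B, C, H, W); targets are per-channel lists of length C
        mean = x.mean(dim=(2, 3), keepdim=True)
        std = x.std(dim=(2, 3), keepdim=True)
        x = (x - mean) / std  # per-image, per-channel zero mean and unit std
        t_mean = torch.tensor(target_mean, device=x.device, dtype=x.dtype).view(1, -1, 1, 1)
        t_std = torch.tensor(target_std, device=x.device, dtype=x.dtype).view(1, -1, 1, 1)
        return x * t_std + t_mean  # move onto the target statistics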
@@ -21,6 +21,7 @@ from toolkit.clip_vision_adapter import ClipVisionAdapter
from toolkit.data_loader import get_dataloader_from_datasets, trigger_dataloader_setup_epoch
from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
from toolkit.embedding import Embedding
from toolkit.image_utils import show_tensors, show_latents
from toolkit.ip_adapter import IPAdapter
from toolkit.lora_special import LoRASpecialNetwork
from toolkit.lorm import convert_diffusers_unet_to_lorm, count_parameters, print_lorm_extract_details, \
@@ -680,9 +681,56 @@ class BaseSDTrainProcess(BaseTrainProcess):
    latents = batch.latents.to(self.device_torch, dtype=dtype)
    batch.latents = latents
else:
    # normalize images to the target channel-wise statistics
    if self.train_config.standardize_images:
        if self.sd.is_xl or self.sd.is_vega or self.sd.is_ssd:
            target_mean_list = [0.0002, -0.1034, -0.1879]
            target_std_list = [0.5436, 0.5116, 0.5033]
        else:
            target_mean_list = [-0.0739, -0.1597, -0.2380]
            target_std_list = [0.5623, 0.5295, 0.5347]
            # Mean: tensor([-0.0739, -0.1597, -0.2380])
            # Standard Deviation: tensor([0.5623, 0.5295, 0.5347])
        imgs_channel_mean = imgs.mean(dim=(2, 3), keepdim=True)
        imgs_channel_std = imgs.std(dim=(2, 3), keepdim=True)
        imgs = (imgs - imgs_channel_mean) / imgs_channel_std
        target_mean = torch.tensor(target_mean_list, device=self.device_torch, dtype=dtype)
        target_std = torch.tensor(target_std_list, device=self.device_torch, dtype=dtype)
        # expand them to match dims (1, C, 1, 1)
        target_mean = target_mean.unsqueeze(0).unsqueeze(2).unsqueeze(3)
        target_std = target_std.unsqueeze(0).unsqueeze(2).unsqueeze(3)

        imgs = imgs * target_std + target_mean
        batch.tensor = imgs

        # show_tensors(imgs, 'imgs')

    latents = self.sd.encode_images(imgs)
    batch.latents = latents

    if self.train_config.standardize_latents:
        if self.sd.is_xl or self.sd.is_vega or self.sd.is_ssd:
            target_mean_list = [-0.1075, 0.0231, -0.0135, 0.2164]
            target_std_list = [0.8979, 0.7505, 0.9150, 0.7451]
        else:
            target_mean_list = [0.2949, -0.3188, 0.0807, 0.1929]
            target_std_list = [0.8560, 0.9629, 0.7778, 0.6719]

        latents_channel_mean = latents.mean(dim=(2, 3), keepdim=True)
        latents_channel_std = latents.std(dim=(2, 3), keepdim=True)
        latents = (latents - latents_channel_mean) / latents_channel_std
        target_mean = torch.tensor(target_mean_list, device=self.device_torch, dtype=dtype)
        target_std = torch.tensor(target_std_list, device=self.device_torch, dtype=dtype)
        # expand them to match dims (1, C, 1, 1)
        target_mean = target_mean.unsqueeze(0).unsqueeze(2).unsqueeze(3)
        target_std = target_std.unsqueeze(0).unsqueeze(2).unsqueeze(3)

        latents = latents * target_std + target_mean
        batch.latents = latents

        # show_latents(latents, self.sd.vae, 'latents')

if batch.unconditional_tensor is not None and batch.unconditional_latents is None:
    unconditional_imgs = batch.unconditional_tensor
    unconditional_imgs = unconditional_imgs.to(self.device_torch, dtype=dtype)
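As a quick sanity check (hypothetical snippet, not part of the commit), the per-image, per-channel statistics after either block should sit exactly on the chosen targets:

    # run right after the image standardization block above
    print(imgs.mean(dim=(2, 3)))  # each row should match target_mean_list
    print(imgs.std(dim=(2, 3)))   # each row should match target_std_list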
@@ -793,6 +841,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
latents = latents * self.train_config.latent_multiplier

# normalize latents to a mean of 0 and an std of 1
# mean_zero_latents = latents - latents.mean()
# latents = mean_zero_latents / mean_zero_latents.std()

if batch.unconditional_latents is not None:
    batch.unconditional_latents = batch.unconditional_latents * self.train_config.latent_multiplier
@@ -257,6 +257,10 @@ class TrainConfig:
        if match_adapter_assist and self.match_adapter_chance == 0.0:
            self.match_adapter_chance = 1.0

        # standardize inputs to the mean and std of the model's knowledge
        self.standardize_images = kwargs.get('standardize_images', False)
        self.standardize_latents = kwargs.get('standardize_latents', False)


class ModelConfig:
    def __init__(self, **kwargs):
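Both flags are read from kwargs and default to False, so existing configs keep their previous behavior. A hedged usage sketch (assuming TrainConfig can be built from keyword arguments alone, as the kwargs.get defaults suggest):

    cfg = TrainConfig(standardize_images=True, standardize_latents=True)
    assert cfg.standardize_images and cfg.standardize_latents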
@@ -1054,13 +1054,17 @@ class PoiFileItemDTOMixin:
        # Use the maximum of the scale factors to ensure both dimensions are scaled above the bucket resolution
        max_scale_factor = max(width_scale_factor, height_scale_factor)

        self.scale_to_width = int(initial_width * max_scale_factor)
        self.scale_to_height = int(initial_height * max_scale_factor)
        self.scale_to_width = math.ceil(initial_width * max_scale_factor)
        self.scale_to_height = math.ceil(initial_height * max_scale_factor)
        self.crop_width = bucket_resolution['width']
        self.crop_height = bucket_resolution['height']
        self.crop_x = int(poi_x * max_scale_factor)
        self.crop_y = int(poi_y * max_scale_factor)

        if self.scale_to_width < self.crop_x + self.crop_width or self.scale_to_height < self.crop_y + self.crop_height:
            # todo look into this. This still happens sometimes
            print('size mismatch')

        return True
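This hunk replaces the old int() assignments with math.ceil(): truncating the scaled size can leave a dimension a pixel short of the bucket, letting crop_x + crop_width (or the height equivalent) spill past the scaled image, and rounding up removes that shortfall, which the new size-mismatch check also guards against. A toy illustration (values invented for the example):

    import math

    initial_width = 853
    bucket_width = 1024
    scale = bucket_width / initial_width   # binding scale factor for this axis

    scaled = initial_width * scale         # mathematically 1024, but can land at 1023.999... in floats
    print(int(scaled))                     # truncation can give 1023, one pixel short of the crop
    print(math.ceil(scaled))               # rounding up gives at least 1024, so the crop always fits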