Bug fixes and a new method to standardize values in both latent and pixel space before feeding them into the network. Target values were measured over large generated regularization sets.

Jaret Burkett
2023-12-26 06:19:48 -07:00
parent 27ad79053e
commit d11ed7f66c
3 changed files with 64 additions and 2 deletions
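
In outline, the standardization this commit adds works by z-scoring each channel against its own batch statistics and then rescaling to fixed per-channel targets. A minimal self-contained sketch of that transform (illustrative names, not the trainer's actual API):

import torch

def standardize_channels(x: torch.Tensor, target_mean, target_std) -> torch.Tensor:
    """Shift/scale each channel of a (B, C, H, W) tensor to target statistics."""
    mean = x.mean(dim=(2, 3), keepdim=True)   # per-sample, per-channel mean
    std = x.std(dim=(2, 3), keepdim=True)     # per-sample, per-channel std
    x = (x - mean) / std                      # z-score every channel
    t_mean = torch.tensor(target_mean, device=x.device, dtype=x.dtype).view(1, -1, 1, 1)
    t_std = torch.tensor(target_std, device=x.device, dtype=x.dtype).view(1, -1, 1, 1)
    return x * t_std + t_mean                 # rescale to the target distribution

out = standardize_channels(torch.randn(4, 3, 64, 64),
                           [-0.0739, -0.1597, -0.2380], [0.5623, 0.5295, 0.5347])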


@@ -21,6 +21,7 @@ from toolkit.clip_vision_adapter import ClipVisionAdapter
 from toolkit.data_loader import get_dataloader_from_datasets, trigger_dataloader_setup_epoch
 from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
 from toolkit.embedding import Embedding
+from toolkit.image_utils import show_tensors, show_latents
 from toolkit.ip_adapter import IPAdapter
 from toolkit.lora_special import LoRASpecialNetwork
 from toolkit.lorm import convert_diffusers_unet_to_lorm, count_parameters, print_lorm_extract_details, \
@@ -680,9 +681,56 @@ class BaseSDTrainProcess(BaseTrainProcess):
             latents = batch.latents.to(self.device_torch, dtype=dtype)
             batch.latents = latents
         else:
+            # normalize images to the target distribution before encoding
+            if self.train_config.standardize_images:
+                # per-channel pixel targets measured over generated regularization sets
+                if self.sd.is_xl or self.sd.is_vega or self.sd.is_ssd:
+                    target_mean_list = [0.0002, -0.1034, -0.1879]
+                    target_std_list = [0.5436, 0.5116, 0.5033]
+                else:
+                    target_mean_list = [-0.0739, -0.1597, -0.2380]
+                    target_std_list = [0.5623, 0.5295, 0.5347]
+                imgs_channel_mean = imgs.mean(dim=(2, 3), keepdim=True)
+                imgs_channel_std = imgs.std(dim=(2, 3), keepdim=True)
+                imgs = (imgs - imgs_channel_mean) / imgs_channel_std
+                target_mean = torch.tensor(target_mean_list, device=self.device_torch, dtype=dtype)
+                target_std = torch.tensor(target_std_list, device=self.device_torch, dtype=dtype)
+                # expand to (1, C, 1, 1) so they broadcast over batch and spatial dims
+                target_mean = target_mean.unsqueeze(0).unsqueeze(2).unsqueeze(3)
+                target_std = target_std.unsqueeze(0).unsqueeze(2).unsqueeze(3)
+                imgs = imgs * target_std + target_mean
+                batch.tensor = imgs
+                # show_tensors(imgs, 'imgs')
             latents = self.sd.encode_images(imgs)
             batch.latents = latents
+            if self.train_config.standardize_latents:
+                # per-channel latent targets measured over generated regularization sets
+                if self.sd.is_xl or self.sd.is_vega or self.sd.is_ssd:
+                    target_mean_list = [-0.1075, 0.0231, -0.0135, 0.2164]
+                    target_std_list = [0.8979, 0.7505, 0.9150, 0.7451]
+                else:
+                    target_mean_list = [0.2949, -0.3188, 0.0807, 0.1929]
+                    target_std_list = [0.8560, 0.9629, 0.7778, 0.6719]
+                latents_channel_mean = latents.mean(dim=(2, 3), keepdim=True)
+                latents_channel_std = latents.std(dim=(2, 3), keepdim=True)
+                latents = (latents - latents_channel_mean) / latents_channel_std
+                target_mean = torch.tensor(target_mean_list, device=self.device_torch, dtype=dtype)
+                target_std = torch.tensor(target_std_list, device=self.device_torch, dtype=dtype)
+                # expand to (1, C, 1, 1) so they broadcast over batch and spatial dims
+                target_mean = target_mean.unsqueeze(0).unsqueeze(2).unsqueeze(3)
+                target_std = target_std.unsqueeze(0).unsqueeze(2).unsqueeze(3)
+                latents = latents * target_std + target_mean
+                batch.latents = latents
+                # show_latents(latents, self.sd.vae, 'latents')
         if batch.unconditional_tensor is not None and batch.unconditional_latents is None:
             unconditional_imgs = batch.unconditional_tensor
             unconditional_imgs = unconditional_imgs.to(self.device_torch, dtype=dtype)
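
The commit message says the targets were measured over large generated regularization sets; the measurement script itself is not part of the commit. A plausible sketch of how such per-channel statistics could be accumulated (reg_loader is a hypothetical DataLoader yielding (B, C, H, W) batches):

import torch

@torch.no_grad()
def measure_channel_stats(reg_loader):
    """Accumulate per-channel mean/std over a dataset of (B, C, H, W) batches."""
    total, sq_total, count = None, None, 0
    for batch in reg_loader:
        flat = batch.permute(1, 0, 2, 3).reshape(batch.shape[1], -1)  # (C, B*H*W)
        if total is None:
            total = flat.sum(dim=1)
            sq_total = (flat ** 2).sum(dim=1)
        else:
            total += flat.sum(dim=1)
            sq_total += (flat ** 2).sum(dim=1)
        count += flat.shape[1]
    mean = total / count
    std = (sq_total / count - mean ** 2).sqrt()  # var = E[x^2] - E[x]^2
    return mean, std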
@@ -793,6 +841,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
             latents = latents * self.train_config.latent_multiplier
+            # normalize latents to a mean of 0 and a std of 1
+            # mean_zero_latents = latents - latents.mean()
+            # latents = mean_zero_latents / mean_zero_latents.std()
+            if batch.unconditional_latents is not None:
+                batch.unconditional_latents = batch.unconditional_latents * self.train_config.latent_multiplier
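
For context on latent_multiplier above: in diffusers-based SD trainers this multiplier conventionally corresponds to the VAE scaling factor (0.18215 for the SD 1.x VAE), which brings latents to roughly unit scale. A sketch of that convention, assuming (it is not shown in this diff) that the trainer follows it:

import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
imgs = torch.randn(1, 3, 512, 512)  # placeholder batch in [-1, 1]
latents = vae.encode(imgs).latent_dist.sample() * vae.config.scaling_factor  # 0.18215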


@@ -257,6 +257,10 @@ class TrainConfig:
         if match_adapter_assist and self.match_adapter_chance == 0.0:
             self.match_adapter_chance = 1.0
+        # standardize inputs to the mean and std of the model's knowledge
+        self.standardize_images = kwargs.get('standardize_images', False)
+        self.standardize_latents = kwargs.get('standardize_latents', False)
 
 class ModelConfig:
     def __init__(self, **kwargs):
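
Both options default to False, so existing configs are unaffected; enabling them is opt-in through the kwargs that feed TrainConfig. A hypothetical example (constructed directly here; in practice the keys come from the training config file):

train_config = TrainConfig(
    standardize_images=True,   # remap pixel channels to the measured targets
    standardize_latents=True,  # remap VAE latent channels to the measured targets
)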


@@ -1054,13 +1054,17 @@ class PoiFileItemDTOMixin:
         # Use the maximum of the scale factors to ensure both dimensions are scaled above the bucket resolution
         max_scale_factor = max(width_scale_factor, height_scale_factor)
-        self.scale_to_width = int(initial_width * max_scale_factor)
-        self.scale_to_height = int(initial_height * max_scale_factor)
+        self.scale_to_width = math.ceil(initial_width * max_scale_factor)
+        self.scale_to_height = math.ceil(initial_height * max_scale_factor)
         self.crop_width = bucket_resolution['width']
         self.crop_height = bucket_resolution['height']
         self.crop_x = int(poi_x * max_scale_factor)
         self.crop_y = int(poi_y * max_scale_factor)
+        if self.scale_to_width < self.crop_x + self.crop_width or self.scale_to_height < self.crop_y + self.crop_height:
+            # todo look into this. This still happens sometimes
+            print('size mismatch')
+            return True
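
The int -> math.ceil change matters because int() truncates toward zero: floating-point rounding can leave the limiting dimension fractionally below the bucket size, and truncation then drops it a full pixel short, making the later crop overflow. A quick illustration of the failure mode:

import math

scale = 512 / 853    # width is the limiting dimension; it should land on 512
w = 853 * scale      # floating point may yield 511.99999999999994
print(int(w))        # can truncate to 511, one pixel short of the bucket
print(math.ceil(w))  # always covers the bucket: 512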