diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index e1450593..ae0bb153 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -457,7 +457,7 @@ class SDTrainer(BaseSDTrainProcess): if self.train_config.target_norm_std: # seperate out the batch and channels pred_std = noise_pred.std([2, 3], keepdim=True) - norm_std_loss = torch.abs(1.0 - pred_std).mean() + norm_std_loss = torch.abs(self.train_config.target_norm_std_value - pred_std).mean() loss = loss + norm_std_loss diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index e0a40ef4..5a6287ec 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -24,7 +24,7 @@ from toolkit.data_loader import get_dataloader_from_datasets, trigger_dataloader from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO from toolkit.ema import ExponentialMovingAverage from toolkit.embedding import Embedding -from toolkit.image_utils import show_tensors, show_latents +from toolkit.image_utils import show_tensors, show_latents, reduce_contrast from toolkit.ip_adapter import IPAdapter from toolkit.lora_special import LoRASpecialNetwork from toolkit.lorm import convert_diffusers_unet_to_lorm, count_parameters, print_lorm_extract_details, \ @@ -811,7 +811,8 @@ class BaseSDTrainProcess(BaseTrainProcess): imgs = batch.tensor imgs = imgs.to(self.device_torch, dtype=dtype) if self.train_config.img_multiplier is not None: - imgs = imgs * self.train_config.img_multiplier + # do it as contrast + imgs = reduce_contrast(imgs, self.train_config.img_multiplier) if batch.latents is not None: latents = batch.latents.to(self.device_torch, dtype=dtype) batch.latents = latents diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index af95e9b7..950d6de4 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -354,6 +354,7 
@@ class TrainConfig: # adds an additional loss to the network to encourage it output a normalized standard deviation self.target_norm_std = kwargs.get('target_norm_std', None) + self.target_norm_std_value = kwargs.get('target_norm_std_value', 1.0) class ModelConfig: diff --git a/toolkit/image_utils.py b/toolkit/image_utils.py index 3c1dbe39..9b9f306e 100644 --- a/toolkit/image_utils.py +++ b/toolkit/image_utils.py @@ -495,6 +495,19 @@ def on_exit(): cv2.destroyAllWindows() +def reduce_contrast(tensor, factor): + # Ensure factor is between 0 and 1 + factor = max(0, min(factor, 1)) + + # Calculate the mean of the tensor + mean = torch.mean(tensor) + + # Reduce contrast + adjusted_tensor = (tensor - mean) * factor + mean + + # Clip values to ensure they stay within -1 to 1 range + return torch.clamp(adjusted_tensor, -1.0, 1.0) + atexit.register(on_exit) if __name__ == "__main__":