From e4ae97e790059473df4d8a9fa53354a0ddde9c7d Mon Sep 17 00:00:00 2001
From: max
Date: Thu, 18 Sep 2025 01:11:19 +0300
Subject: [PATCH 1/2] add dataset-level distillation-style regularization

---
 extensions_built_in/sd_trainer/SDTrainer.py | 10 +++++-----
 toolkit/config_modules.py                   |  1 +
 2 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index f1d4e829..3001ea3c 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -1154,6 +1154,11 @@ class SDTrainer(BaseSDTrainProcess):
                 has_clip_image_embeds = True # we are caching embeds, handle that differently
                 has_clip_image = False
 
+        # do prior pred if prior regularization batch
+        do_reg_prior = False
+        if any([batch.file_items[idx].prior_reg for idx in range(len(batch.file_items))]):
+            do_reg_prior = True
+
         if self.adapter is not None and isinstance(self.adapter, IPAdapter) and not has_clip_image and has_adapter_img:
             raise ValueError(
                 "IPAdapter control image is now 'clip_image_path' instead of 'control_path'. Please update your dataset config ")
@@ -1648,11 +1653,6 @@ class SDTrainer(BaseSDTrainProcess):
 
         prior_pred = None
 
-        do_reg_prior = False
-        # if is_reg and (self.network is not None or self.adapter is not None):
-        #     # we are doing a reg image and we have a network or adapter
-        #     do_reg_prior = True
-
         do_inverted_masked_prior = False
         if self.train_config.inverted_mask_prior and batch.mask_tensor is not None:
             do_inverted_masked_prior = True
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index 2526d86e..085fe7c1 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -806,6 +806,7 @@ class DatasetConfig:
         self.buckets: bool = kwargs.get('buckets', True)
         self.bucket_tolerance: int = kwargs.get('bucket_tolerance', 64)
         self.is_reg: bool = kwargs.get('is_reg', False)
+        self.prior_reg: bool = kwargs.get('prior_reg', False)
         self.network_weight: float = float(kwargs.get('network_weight', 1.0))
         self.token_dropout_rate: float = float(kwargs.get('token_dropout_rate', 0.0))
         self.shuffle_tokens: bool = kwargs.get('shuffle_tokens', False)

From e27e229b3664fa5594ae79707593dd6b0d9a31a1 Mon Sep 17 00:00:00 2001
From: squewel
Date: Thu, 18 Sep 2025 02:09:39 +0300
Subject: [PATCH 2/2] add prior_reg flag to FileItemDTO

---
 toolkit/data_transfer_object/data_loader.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/toolkit/data_transfer_object/data_loader.py b/toolkit/data_transfer_object/data_loader.py
index b6863663..a7ed3759 100644
--- a/toolkit/data_transfer_object/data_loader.py
+++ b/toolkit/data_transfer_object/data_loader.py
@@ -121,6 +121,7 @@ class FileItemDTO(
 
         self.network_weight: float = self.dataset_config.network_weight
         self.is_reg = self.dataset_config.is_reg
+        self.prior_reg = self.dataset_config.prior_reg
         self.tensor: Union[torch.Tensor, None] = None
 
     def cleanup(self):
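
For review, a minimal self-contained sketch of the data flow these two patches set up: DatasetConfig reads prior_reg from the dataset kwargs (PATCH 1/2), each FileItemDTO copies it from its dataset config (PATCH 2/2), and SDTrainer turns on the prior prediction when any item in the batch carries the flag. The stub classes below only mirror the patched attributes; Batch is a hypothetical stand-in for the real batch DTO, and all surrounding trainer logic is elided.

    # Sketch of the prior_reg flow under the patches above; assumes only
    # the attributes touched by the diffs, everything else is stubbed out.
    from dataclasses import dataclass, field
    from typing import List


    class DatasetConfig:
        def __init__(self, **kwargs):
            # defaults to False, so existing dataset configs keep their behavior
            self.is_reg: bool = kwargs.get('is_reg', False)
            self.prior_reg: bool = kwargs.get('prior_reg', False)


    class FileItemDTO:
        def __init__(self, dataset_config: DatasetConfig):
            # each file item copies the flags from its dataset config
            self.is_reg = dataset_config.is_reg
            self.prior_reg = dataset_config.prior_reg


    @dataclass
    class Batch:
        # hypothetical stand-in for the real batch DTO
        file_items: List[FileItemDTO] = field(default_factory=list)


    def wants_reg_prior(batch: Batch) -> bool:
        # same condition the SDTrainer hunk adds: one prior_reg item is
        # enough to enable the prior prediction for the whole batch
        return any(item.prior_reg for item in batch.file_items)


    if __name__ == '__main__':
        reg_ds = DatasetConfig(prior_reg=True)
        plain_ds = DatasetConfig()
        mixed = Batch(file_items=[FileItemDTO(plain_ds), FileItemDTO(reg_ds)])
        assert wants_reg_prior(mixed)
        assert not wants_reg_prior(Batch(file_items=[FileItemDTO(plain_ds)]))

Note the batch-level semantics this implies: because the trainer checks any() over file_items, a batch that mixes prior_reg and non-prior_reg datasets enables the prior prediction for every item in that batch, not just the flagged ones.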