mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Merge pull request #426 from squewel/prior_reg
Dataset-level prior regularization
This commit is contained in:
@@ -1154,6 +1154,11 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
has_clip_image_embeds = True # we are caching embeds, handle that differently
|
has_clip_image_embeds = True # we are caching embeds, handle that differently
|
||||||
has_clip_image = False
|
has_clip_image = False
|
||||||
|
|
||||||
|
# do prior pred if prior regularization batch
|
||||||
|
do_reg_prior = False
|
||||||
|
if any([batch.file_items[idx].prior_reg for idx in range(len(batch.file_items))]):
|
||||||
|
do_reg_prior = True
|
||||||
|
|
||||||
if self.adapter is not None and isinstance(self.adapter, IPAdapter) and not has_clip_image and has_adapter_img:
|
if self.adapter is not None and isinstance(self.adapter, IPAdapter) and not has_clip_image and has_adapter_img:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"IPAdapter control image is now 'clip_image_path' instead of 'control_path'. Please update your dataset config ")
|
"IPAdapter control image is now 'clip_image_path' instead of 'control_path'. Please update your dataset config ")
|
||||||
@@ -1648,11 +1653,6 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
|
|
||||||
prior_pred = None
|
prior_pred = None
|
||||||
|
|
||||||
do_reg_prior = False
|
|
||||||
# if is_reg and (self.network is not None or self.adapter is not None):
|
|
||||||
# # we are doing a reg image and we have a network or adapter
|
|
||||||
# do_reg_prior = True
|
|
||||||
|
|
||||||
do_inverted_masked_prior = False
|
do_inverted_masked_prior = False
|
||||||
if self.train_config.inverted_mask_prior and batch.mask_tensor is not None:
|
if self.train_config.inverted_mask_prior and batch.mask_tensor is not None:
|
||||||
do_inverted_masked_prior = True
|
do_inverted_masked_prior = True
|
||||||
|
|||||||
@@ -813,6 +813,7 @@ class DatasetConfig:
|
|||||||
self.buckets: bool = kwargs.get('buckets', True)
|
self.buckets: bool = kwargs.get('buckets', True)
|
||||||
self.bucket_tolerance: int = kwargs.get('bucket_tolerance', 64)
|
self.bucket_tolerance: int = kwargs.get('bucket_tolerance', 64)
|
||||||
self.is_reg: bool = kwargs.get('is_reg', False)
|
self.is_reg: bool = kwargs.get('is_reg', False)
|
||||||
|
self.prior_reg: bool = kwargs.get('prior_reg', False)
|
||||||
self.network_weight: float = float(kwargs.get('network_weight', 1.0))
|
self.network_weight: float = float(kwargs.get('network_weight', 1.0))
|
||||||
self.token_dropout_rate: float = float(kwargs.get('token_dropout_rate', 0.0))
|
self.token_dropout_rate: float = float(kwargs.get('token_dropout_rate', 0.0))
|
||||||
self.shuffle_tokens: bool = kwargs.get('shuffle_tokens', False)
|
self.shuffle_tokens: bool = kwargs.get('shuffle_tokens', False)
|
||||||
|
|||||||
@@ -121,6 +121,7 @@ class FileItemDTO(
|
|||||||
|
|
||||||
self.network_weight: float = self.dataset_config.network_weight
|
self.network_weight: float = self.dataset_config.network_weight
|
||||||
self.is_reg = self.dataset_config.is_reg
|
self.is_reg = self.dataset_config.is_reg
|
||||||
|
self.prior_reg = self.dataset_config.prior_reg
|
||||||
self.tensor: Union[torch.Tensor, None] = None
|
self.tensor: Union[torch.Tensor, None] = None
|
||||||
|
|
||||||
def cleanup(self):
|
def cleanup(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user