Added mask diffirential mask dialation for flex2. Handle video for the i2v adapter

This commit is contained in:
Jaret Burkett
2025-04-10 11:50:01 -06:00
parent 9794416a5d
commit 059155174a
5 changed files with 118 additions and 3 deletions

View File

@@ -771,7 +771,11 @@ class SDTrainer(BaseSDTrainProcess):
def train_single_accumulation(self, batch: DataLoaderBatchDTO):
self.timer.start('preprocess_batch')
if isinstance(self.adapter, CustomAdapter):
batch = self.adapter.edit_batch_raw(batch)
batch = self.preprocess_batch(batch)
if isinstance(self.adapter, CustomAdapter):
batch = self.adapter.edit_batch_processed(batch)
dtype = get_torch_dtype(self.train_config.dtype)
# sanity check
if self.sd.vae.dtype != self.sd.vae_torch_dtype: