From ab2267498088fb1fa762907750ec5df5ce0d4a14 Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Thu, 10 Oct 2024 07:31:33 -0600
Subject: [PATCH] Allow for a default caption file in the folder. Minor bug
 fixes.

---
 extensions_built_in/sd_trainer/SDTrainer.py | 2 +-
 jobs/process/BaseSDTrainProcess.py          | 5 ++++-
 toolkit/dataloader_mixins.py                | 7 +++++++
 toolkit/guidance.py                         | 6 +++---
 4 files changed, 15 insertions(+), 5 deletions(-)

diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index e5fa9e87..13df73f7 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -63,7 +63,7 @@ class SDTrainer(BaseSDTrainProcess):
         self.is_bfloat = self.train_config.dtype == "bfloat16" or self.train_config.dtype == "bf16"

         self.do_grad_scale = True
-        if self.is_fine_tuning:
+        if self.is_fine_tuning and self.is_bfloat:
             self.do_grad_scale = False
         if self.adapter_config is not None:
             if self.adapter_config.train:
diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index 5fb0d493..3067ae57 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -1661,8 +1661,10 @@ class BaseSDTrainProcess(BaseTrainProcess):

         batch_list = []
         for b in range(self.train_config.gradient_accumulation):
+            # keep track to alternate on an accumulation step for reg
+            batch_step = step
             # don't do a reg step on sample or save steps as we dont want to normalize on those
-            if step % 2 == 0 and dataloader_reg is not None and not is_save_step and not is_sample_step:
+            if batch_step % 2 == 0 and dataloader_reg is not None and not is_save_step and not is_sample_step:
                 try:
                     with self.timer('get_batch:reg'):
                         batch = next(dataloader_iterator_reg)
@@ -1698,6 +1700,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
             else:
                 batch = None
             batch_list.append(batch)
+            batch_step += 1

         # setup accumulation
         if self.train_config.gradient_accumulation_steps == -1:
diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py
index 87e0679b..539ba479 100644
--- a/toolkit/dataloader_mixins.py
+++ b/toolkit/dataloader_mixins.py
@@ -121,6 +121,9 @@ class CaptionMixin:
                 prompt_path = path_no_ext + '.' + ext
                 if os.path.exists(prompt_path):
                     break
+
+        # allow folders to have a default prompt
+        default_prompt_path = os.path.join(os.path.dirname(img_path), 'default.txt')

         if os.path.exists(prompt_path):
             with open(prompt_path, 'r', encoding='utf-8') as f:
@@ -131,6 +134,10 @@ class CaptionMixin:
                     if 'caption' in prompt:
                         prompt = prompt['caption']

+                prompt = clean_caption(prompt)
+        elif os.path.exists(default_prompt_path):
+            with open(default_prompt_path, 'r', encoding='utf-8') as f:
+                prompt = f.read()
                 prompt = clean_caption(prompt)
         else:
             prompt = ''
diff --git a/toolkit/guidance.py b/toolkit/guidance.py
index 0f718614..dcf28204 100644
--- a/toolkit/guidance.py
+++ b/toolkit/guidance.py
@@ -481,9 +481,9 @@ def get_guided_loss_polarity(

     loss = pred_loss + pred_neg_loss

-    if sd.is_flow_matching:
-        timestep_weight = sd.noise_scheduler.get_weights_for_timesteps(timesteps).to(loss.device, dtype=loss.dtype).detach()
-        loss = loss * timestep_weight
+    # if sd.is_flow_matching:
+    #     timestep_weight = sd.noise_scheduler.get_weights_for_timesteps(timesteps).to(loss.device, dtype=loss.dtype).detach()
+    #     loss = loss * timestep_weight

     loss = loss.mean([1, 2, 3])
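
Note on extensions_built_in/sd_trainer/SDTrainer.py: grad scaling is now disabled only when fine-tuning in bfloat16, rather than for all fine-tuning. bf16 shares float32's exponent range, so loss scaling adds nothing there, while fp16 runs still need it to keep small gradients from underflowing. A minimal sketch of how a flag like do_grad_scale typically gates torch.cuda.amp.GradScaler (model, optimizer, and compute_loss are hypothetical placeholders, not names from this repo):

    import torch

    scaler = torch.cuda.amp.GradScaler()

    def training_step(model, optimizer, compute_loss, do_grad_scale: bool):
        optimizer.zero_grad()
        loss = compute_loss(model)
        if do_grad_scale:
            # fp16 path: scale the loss so small gradients survive the low-precision cast
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            # bf16 fine-tuning path: exponent range matches fp32, no scaler needed
            loss.backward()
            optimizer.step()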
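
Note on jobs/process/BaseSDTrainProcess.py: the intent is to alternate regularization and normal batches across gradient-accumulation micro-steps instead of keying the choice off the outer step alone, while still skipping reg batches on save and sample steps. A minimal sketch of that alternation with hypothetical names (for the += 1 to take effect, the counter has to be initialized once before the accumulation loop):

    def gather_batches(step, accumulation, data_iter, reg_iter=None):
        batch_list = []
        # seed from the global step so consecutive optimizer steps stay
        # out of phase, then advance once per micro-step
        batch_step = step
        for _ in range(accumulation):
            if reg_iter is not None and batch_step % 2 == 0:
                batch_list.append(next(reg_iter))  # regularization batch
            else:
                batch_list.append(next(data_iter))  # normal training batch
            batch_step += 1
        return batch_list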
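
Note on toolkit/dataloader_mixins.py: caption lookup now falls back to a folder-level default.txt when an image has no sidecar caption file of its own, so a whole folder can share one caption. A minimal sketch of the resulting resolution order (resolve_caption is a hypothetical helper and the extension list is illustrative):

    import os

    def resolve_caption(img_path: str) -> str:
        path_no_ext = os.path.splitext(img_path)[0]
        # 1) per-image sidecar caption file, first matching extension wins
        for ext in ('txt', 'json', 'caption'):
            prompt_path = path_no_ext + '.' + ext
            if os.path.exists(prompt_path):
                with open(prompt_path, 'r', encoding='utf-8') as f:
                    return f.read()
        # 2) folder-level default caption shared by every image in the folder
        default_prompt_path = os.path.join(os.path.dirname(img_path), 'default.txt')
        if os.path.exists(default_prompt_path):
            with open(default_prompt_path, 'r', encoding='utf-8') as f:
                return f.read()
        # 3) nothing found: empty caption
        return ''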