From a767b82b60e74c00d8d2cf4933129f51b2fe23aa Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Thu, 25 Dec 2025 16:57:34 +0000
Subject: [PATCH] Fixed issue with new logger when OOMing

---
 jobs/process/BaseSDTrainProcess.py          | 17 ++++++++++-------
 toolkit/data_transfer_object/data_loader.py |  9 +++++++--
 2 files changed, 17 insertions(+), 9 deletions(-)

diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index f90fc8f8..91f62a5a 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -2212,6 +2212,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
         with torch.no_grad():
             # torch.cuda.empty_cache()
             # if optimizer has get_lrs method, then use it
+            learning_rate = 0.0
             if not did_oom and loss_dict is not None:
                 if hasattr(optimizer, 'get_avg_learning_rate'):
                     learning_rate = optimizer.get_avg_learning_rate()
@@ -2282,9 +2283,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 # log to tensorboard
                 if self.accelerator.is_main_process:
                     if self.writer is not None:
-                        for key, value in loss_dict.items():
-                            self.writer.add_scalar(f"{key}", value, self.step_num)
-                        self.writer.add_scalar(f"lr", learning_rate, self.step_num)
+                        if loss_dict is not None:
+                            for key, value in loss_dict.items():
+                                self.writer.add_scalar(f"{key}", value, self.step_num)
+                            self.writer.add_scalar(f"lr", learning_rate, self.step_num)
                 if self.progress_bar is not None:
                     self.progress_bar.unpause()

@@ -2293,10 +2295,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
                         self.logger.log({
                             'learning_rate': learning_rate,
                         })
-                        for key, value in loss_dict.items():
-                            self.logger.log({
-                                f'loss/{key}': value,
-                            })
+                        if loss_dict is not None:
+                            for key, value in loss_dict.items():
+                                self.logger.log({
+                                    f'loss/{key}': value,
+                                })
                 elif self.logging_config.log_every is None:
                     if self.accelerator.is_main_process:
                         # log every step
diff --git a/toolkit/data_transfer_object/data_loader.py b/toolkit/data_transfer_object/data_loader.py
index ca27dd47..d72364bd 100644
--- a/toolkit/data_transfer_object/data_loader.py
+++ b/toolkit/data_transfer_object/data_loader.py
@@ -295,9 +295,14 @@ class DataLoaderBatchDTO:
             prompt_embeds_list = []
             for x in self.file_items:
                 if x.prompt_embeds is None:
-                    prompt_embeds_list.append(base_prompt_embeds)
+                    y = base_prompt_embeds
                 else:
-                    prompt_embeds_list.append(x.prompt_embeds)
+                    y = x.prompt_embeds
+                if x.text_embedding_space_version == "zimage":
+                    # z image text embeds need to be a list if they are not already
+                    if not isinstance(y.text_embeds, list):
+                        y.text_embeds = [y.text_embeds]
+                prompt_embeds_list.append(y)

             self.prompt_embeds = concat_prompt_embeds(prompt_embeds_list)
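
For context on the BaseSDTrainProcess.py change: when a training step OOMs, no loss dict is produced, so the old logging path called .items() on None and the lr scalar could read an unbound learning_rate. A minimal sketch of the guarded pattern, assuming a hypothetical train_step/log_step pair rather than the repo's actual API:

    from typing import Optional

    import torch


    def train_step() -> Optional[dict]:
        # Hypothetical step: returns a loss dict on success, None on OOM.
        try:
            # ... forward/backward pass would go here ...
            return {"loss": 0.123}
        except torch.cuda.OutOfMemoryError:
            torch.cuda.empty_cache()
            return None


    def log_step(loss_dict: Optional[dict], step_num: int) -> None:
        # Default mirrors the `learning_rate = 0.0` added in the patch,
        # so the lr log line is defined even when the step OOMed.
        learning_rate = 0.0
        if loss_dict is not None:
            # .items() is only reached when a loss dict actually exists.
            for key, value in loss_dict.items():
                print(f"step {step_num}: loss/{key}={value}")
        print(f"step {step_num}: lr={learning_rate}")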
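
The data_loader.py change normalizes text_embeds to a list for the "zimage" embedding space before the per-item embeds are concatenated. A minimal sketch of that normalization, using a hypothetical PromptEmbeds stand-in rather than the toolkit's real container class:

    from dataclasses import dataclass
    from typing import List, Union

    import torch


    @dataclass
    class PromptEmbeds:
        # Hypothetical stand-in for the toolkit's prompt-embed container.
        text_embeds: Union[torch.Tensor, List[torch.Tensor]]


    def normalize_for_zimage(pe: PromptEmbeds) -> PromptEmbeds:
        # z image expects text_embeds as a list of tensors; wrap a bare
        # tensor exactly once. The isinstance check makes this idempotent,
        # so a shared base_prompt_embeds object is only wrapped on first use.
        if not isinstance(pe.text_embeds, list):
            pe.text_embeds = [pe.text_embeds]
        return pe

The idempotence matters because base_prompt_embeds can be appended for several file items in the same batch: after the first wrap, text_embeds is already a list and is left untouched on subsequent items.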