Fixed issue with new logger when OOMing

Jaret Burkett
2025-12-25 16:57:34 +00:00
parent 8edf1e44c5
commit a767b82b60
2 changed files with 17 additions and 9 deletions


@@ -2212,6 +2212,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
         with torch.no_grad():
             # torch.cuda.empty_cache()
             # if optimizer has get_lrs method, then use it
+            learning_rate = 0.0
             if not did_oom and loss_dict is not None:
                 if hasattr(optimizer, 'get_avg_learning_rate'):
                     learning_rate = optimizer.get_avg_learning_rate()
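
Why the new learning_rate = 0.0 default matters: when a step OOMs, did_oom is True, the assignment branch is skipped entirely, and the logging code further down still reads learning_rate, which would raise UnboundLocalError. A minimal sketch of the failure mode and the fix, with hypothetical stand-ins for the trainer's state:

    # Minimal sketch (hypothetical names): reproduce the old failure mode.
    from typing import Optional

    def log_step(did_oom: bool, loss_dict: Optional[dict]) -> float:
        learning_rate = 0.0  # the fix: always bound, even on an OOM step
        if not did_oom and loss_dict is not None:
            learning_rate = 1e-4  # stands in for optimizer.get_avg_learning_rate()
        # without the default above, reading learning_rate here raises
        # UnboundLocalError whenever did_oom is True or loss_dict is None
        return learning_rate

    assert log_step(did_oom=True, loss_dict=None) == 0.0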
@@ -2282,9 +2283,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
             # log to tensorboard
             if self.accelerator.is_main_process:
                 if self.writer is not None:
-                    for key, value in loss_dict.items():
-                        self.writer.add_scalar(f"{key}", value, self.step_num)
-                    self.writer.add_scalar(f"lr", learning_rate, self.step_num)
+                    if loss_dict is not None:
+                        for key, value in loss_dict.items():
+                            self.writer.add_scalar(f"{key}", value, self.step_num)
+                    self.writer.add_scalar(f"lr", learning_rate, self.step_num)
                 if self.progress_bar is not None:
                     self.progress_bar.unpause()
@@ -2293,10 +2295,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
                     self.logger.log({
                         'learning_rate': learning_rate,
                     })
-                    for key, value in loss_dict.items():
-                        self.logger.log({
-                            f'loss/{key}': value,
-                        })
+                    if loss_dict is not None:
+                        for key, value in loss_dict.items():
+                            self.logger.log({
+                                f'loss/{key}': value,
+                            })
             elif self.logging_config.log_every is None:
                 if self.accelerator.is_main_process:
                     # log every step
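
Both logging hunks apply the same guard: on an OOM step the training loop hands back loss_dict as None, and the old code called .items() on it directly, raising TypeError: 'NoneType' object is not iterable. A hedged sketch of the pattern, where writer is a stand-in for the trainer's self.writer (the same shape applies to the self.logger.log calls):

    # Sketch of the None-guard used in both hunks (stand-in writer object).
    from typing import Optional

    def write_metrics(writer, loss_dict: Optional[dict],
                      learning_rate: float, step: int) -> None:
        if loss_dict is not None:  # skip loss scalars on an OOM step
            for key, value in loss_dict.items():
                writer.add_scalar(f"{key}", value, step)
        # the learning rate is still logged unconditionally, which is why the
        # learning_rate = 0.0 default in the first hunk is needed
        writer.add_scalar("lr", learning_rate, step)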


@@ -295,9 +295,14 @@ class DataLoaderBatchDTO:
             prompt_embeds_list = []
             for x in self.file_items:
                 if x.prompt_embeds is None:
-                    prompt_embeds_list.append(base_prompt_embeds)
+                    y = base_prompt_embeds
                 else:
-                    prompt_embeds_list.append(x.prompt_embeds)
+                    y = x.prompt_embeds
+                if x.text_embedding_space_version == "zimage":
+                    # z image needs to be a list if it is not already
+                    if not isinstance(y.text_embeds, list):
+                        y.text_embeds = [y.text_embeds]
+                prompt_embeds_list.append(y)
             self.prompt_embeds = concat_prompt_embeds(prompt_embeds_list)
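
The second file's change normalizes embeds before concatenation. The apparent assumption (only the diff is visible here): for the "zimage" text-embedding space, text_embeds must be a list of tensors, but a base or cached embed may hold a bare tensor, so the batch wraps it before concat_prompt_embeds runs. A self-contained sketch with a hypothetical PromptEmbedsStub in place of the real prompt-embeds object:

    # Hypothetical stand-in for the normalization step; assumes "zimage"
    # concatenation expects text_embeds to be a list of tensors.
    import torch

    class PromptEmbedsStub:
        def __init__(self, text_embeds):
            self.text_embeds = text_embeds

    y = PromptEmbedsStub(torch.zeros(1, 77, 768))  # bare tensor, e.g. from cache
    if not isinstance(y.text_embeds, list):
        y.text_embeds = [y.text_embeds]  # wrap so every batch item is a list
    assert isinstance(y.text_embeds, list) and len(y.text_embeds) == 1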