Mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-01-26 16:39:47 +00:00)
Fixed issue with new logger when OOMing
@@ -2212,6 +2212,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
             with torch.no_grad():
                 # torch.cuda.empty_cache()
                 # if optimizer has get_lrs method, then use it
+                learning_rate = 0.0
                 if not did_oom and loss_dict is not None:
                     if hasattr(optimizer, 'get_avg_learning_rate'):
                         learning_rate = optimizer.get_avg_learning_rate()
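
This one-line addition is the core of the OOM fix: previously `learning_rate` was only bound inside the `if not did_oom and loss_dict is not None:` branch, so a step that OOMed reached the logging code below with the name never assigned. A minimal sketch of the failure mode, with illustrative stand-ins for the trainer's locals:

# Minimal repro sketch (illustrative names, not the trainer's actual code).
def log_step(did_oom, loss_dict):
    learning_rate = 0.0  # the fix: bound on every path, including OOM steps
    if not did_oom and loss_dict is not None:
        learning_rate = 1e-4  # stand-in for optimizer.get_avg_learning_rate()
    # without the initialization above, an OOM step would raise
    # UnboundLocalError when the logger reads learning_rate below
    return {'learning_rate': learning_rate}

assert log_step(did_oom=True, loss_dict=None) == {'learning_rate': 0.0}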
@@ -2282,9 +2283,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
             # log to tensorboard
             if self.accelerator.is_main_process:
                 if self.writer is not None:
-                    for key, value in loss_dict.items():
-                        self.writer.add_scalar(f"{key}", value, self.step_num)
-                    self.writer.add_scalar(f"lr", learning_rate, self.step_num)
+                    if loss_dict is not None:
+                        for key, value in loss_dict.items():
+                            self.writer.add_scalar(f"{key}", value, self.step_num)
+                        self.writer.add_scalar(f"lr", learning_rate, self.step_num)
             if self.progress_bar is not None:
                 self.progress_bar.unpause()
 
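
On an OOM step `loss_dict` stays `None`, and the old code called `.items()` on it unconditionally, raising `AttributeError` before any scalar was written. A hedged sketch of the guarded TensorBoard path; `SummaryWriter.add_scalar` is the real torch API, while the function and its parameters are illustrative:

# Hedged sketch of the guarded write (writer/step names are illustrative).
from torch.utils.tensorboard import SummaryWriter

def write_step_scalars(writer: SummaryWriter, loss_dict, learning_rate, step):
    if loss_dict is not None:  # an OOM step produces no losses to log
        for key, value in loss_dict.items():
            writer.add_scalar(key, value, step)
        writer.add_scalar("lr", learning_rate, step)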
@@ -2293,10 +2295,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 self.logger.log({
                     'learning_rate': learning_rate,
                 })
-                for key, value in loss_dict.items():
-                    self.logger.log({
-                        f'loss/{key}': value,
-                    })
+                if loss_dict is not None:
+                    for key, value in loss_dict.items():
+                        self.logger.log({
+                            f'loss/{key}': value,
+                        })
             elif self.logging_config.log_every is None:
                 if self.accelerator.is_main_process:
                     # log every step
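
The same guard is applied to the metric logger, with one asymmetry worth noting: `learning_rate` is still logged unconditionally (which is why the first hunk initializes it to `0.0`), while the per-key loss entries are skipped when the step OOMed. A sketch of the resulting behavior, assuming a wandb-style logger exposing a `log(dict)` method:

# Sketch assuming a wandb-style logger with log(dict); names illustrative.
def log_metrics(logger, loss_dict, learning_rate):
    logger.log({
        'learning_rate': learning_rate,  # always logged; 0.0 on OOM steps
    })
    if loss_dict is not None:  # skip loss entries when the step OOMed
        for key, value in loss_dict.items():
            logger.log({
                f'loss/{key}': value,
            })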
@@ -295,9 +295,14 @@ class DataLoaderBatchDTO:
         prompt_embeds_list = []
         for x in self.file_items:
             if x.prompt_embeds is None:
-                prompt_embeds_list.append(base_prompt_embeds)
+                y = base_prompt_embeds
             else:
-                prompt_embeds_list.append(x.prompt_embeds)
+                y = x.prompt_embeds
+            if x.text_embedding_space_version == "zimage":
+                # z image needs to be a list if it is not already
+                if not isinstance(y.text_embeds, list):
+                    y.text_embeds = [y.text_embeds]
+            prompt_embeds_list.append(y)
         self.prompt_embeds = concat_prompt_embeds(prompt_embeds_list)
 
 
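
The DataLoaderBatchDTO change is separate from the OOM fix: both branches now assign to a temporary `y` instead of appending directly, so zimage-space embeddings can have their `text_embeds` normalized to a list before `concat_prompt_embeds` runs. A hedged sketch of that normalization; `PromptEmbeds` here is a stand-in for the toolkit's actual embedding container:

# Hedged sketch; PromptEmbeds stands in for the toolkit's real container.
class PromptEmbeds:
    def __init__(self, text_embeds):
        self.text_embeds = text_embeds

def normalize_zimage(embed: PromptEmbeds) -> PromptEmbeds:
    # zimage expects text_embeds to be a list of tensors; wrap a bare
    # value so downstream concat code can iterate uniformly
    if not isinstance(embed.text_embeds, list):
        embed.text_embeds = [embed.text_embeds]
    return embed

The isinstance check makes the wrap idempotent: embeddings that already carry a list pass through untouched.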