mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-27 09:44:02 +00:00
Massive speed increases and ram optimizations
This commit is contained in:
@@ -94,85 +94,94 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
|
||||
def hook_train_loop(self, batch):
|
||||
|
||||
self.timer.start('preprocess_batch')
|
||||
dtype = get_torch_dtype(self.train_config.dtype)
|
||||
noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)
|
||||
network_weight_list = batch.get_network_weight_list()
|
||||
self.timer.stop('preprocess_batch')
|
||||
|
||||
with torch.no_grad():
|
||||
adapter_images = None
|
||||
sigmas = None
|
||||
if self.adapter:
|
||||
# todo move this to data loader
|
||||
if batch.control_tensor is not None:
|
||||
adapter_images = batch.control_tensor.to(self.device_torch, dtype=dtype).detach()
|
||||
else:
|
||||
adapter_images = self.get_adapter_images(batch)
|
||||
# not 100% sure what this does. But they do it here
|
||||
# https://github.com/huggingface/diffusers/blob/38a664a3d61e27ab18cd698231422b3c38d6eebf/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L1170
|
||||
# sigmas = self.get_sigmas(timesteps, len(noisy_latents.shape), noisy_latents.dtype)
|
||||
# noisy_latents = noisy_latents / ((sigmas ** 2 + 1) ** 0.5)
|
||||
with self.timer('get_adapter_images'):
|
||||
# todo move this to data loader
|
||||
if batch.control_tensor is not None:
|
||||
adapter_images = batch.control_tensor.to(self.device_torch, dtype=dtype).detach()
|
||||
else:
|
||||
adapter_images = self.get_adapter_images(batch)
|
||||
# not 100% sure what this does. But they do it here
|
||||
# https://github.com/huggingface/diffusers/blob/38a664a3d61e27ab18cd698231422b3c38d6eebf/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L1170
|
||||
# sigmas = self.get_sigmas(timesteps, len(noisy_latents.shape), noisy_latents.dtype)
|
||||
# noisy_latents = noisy_latents / ((sigmas ** 2 + 1) ** 0.5)
|
||||
|
||||
mask_multiplier = 1.0
|
||||
if batch.mask_tensor is not None:
|
||||
# upsampling no supported for bfloat16
|
||||
mask_multiplier = batch.mask_tensor.to(self.device_torch, dtype=torch.float16).detach()
|
||||
# scale down to the size of the latents, mask multiplier shape(bs, 1, width, height), noisy_latents shape(bs, channels, width, height)
|
||||
mask_multiplier = torch.nn.functional.interpolate(
|
||||
mask_multiplier, size=(noisy_latents.shape[2], noisy_latents.shape[3])
|
||||
)
|
||||
# expand to match latents
|
||||
mask_multiplier = mask_multiplier.expand(-1, noisy_latents.shape[1], -1, -1)
|
||||
mask_multiplier = mask_multiplier.to(self.device_torch, dtype=dtype).detach()
|
||||
with self.timer('get_mask_multiplier'):
|
||||
# upsampling no supported for bfloat16
|
||||
mask_multiplier = batch.mask_tensor.to(self.device_torch, dtype=torch.float16).detach()
|
||||
# scale down to the size of the latents, mask multiplier shape(bs, 1, width, height), noisy_latents shape(bs, channels, width, height)
|
||||
mask_multiplier = torch.nn.functional.interpolate(
|
||||
mask_multiplier, size=(noisy_latents.shape[2], noisy_latents.shape[3])
|
||||
)
|
||||
# expand to match latents
|
||||
mask_multiplier = mask_multiplier.expand(-1, noisy_latents.shape[1], -1, -1)
|
||||
mask_multiplier = mask_multiplier.to(self.device_torch, dtype=dtype).detach()
|
||||
|
||||
# flush()
|
||||
self.optimizer.zero_grad()
|
||||
with self.timer('grad_setup'):
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
# text encoding
|
||||
grad_on_text_encoder = False
|
||||
if self.train_config.train_text_encoder:
|
||||
grad_on_text_encoder = True
|
||||
# text encoding
|
||||
grad_on_text_encoder = False
|
||||
if self.train_config.train_text_encoder:
|
||||
grad_on_text_encoder = True
|
||||
|
||||
if self.embedding:
|
||||
grad_on_text_encoder = True
|
||||
if self.embedding:
|
||||
grad_on_text_encoder = True
|
||||
|
||||
# have a blank network so we can wrap it in a context and set multipliers without checking every time
|
||||
if self.network is not None:
|
||||
network = self.network
|
||||
else:
|
||||
network = BlankNetwork()
|
||||
# have a blank network so we can wrap it in a context and set multipliers without checking every time
|
||||
if self.network is not None:
|
||||
network = self.network
|
||||
else:
|
||||
network = BlankNetwork()
|
||||
|
||||
# set the weights
|
||||
network.multiplier = network_weight_list
|
||||
# set the weights
|
||||
network.multiplier = network_weight_list
|
||||
|
||||
# activate network if it exits
|
||||
with network:
|
||||
with torch.set_grad_enabled(grad_on_text_encoder):
|
||||
conditional_embeds = self.sd.encode_prompt(conditioned_prompts).to(self.device_torch, dtype=dtype)
|
||||
if not grad_on_text_encoder:
|
||||
# detach the embeddings
|
||||
conditional_embeds = conditional_embeds.detach()
|
||||
with self.timer('encode_prompt'):
|
||||
with torch.set_grad_enabled(grad_on_text_encoder):
|
||||
conditional_embeds = self.sd.encode_prompt(conditioned_prompts).to(self.device_torch, dtype=dtype)
|
||||
if not grad_on_text_encoder:
|
||||
# detach the embeddings
|
||||
conditional_embeds = conditional_embeds.detach()
|
||||
# flush()
|
||||
pred_kwargs = {}
|
||||
if self.adapter and isinstance(self.adapter, T2IAdapter):
|
||||
down_block_additional_residuals = self.adapter(adapter_images)
|
||||
down_block_additional_residuals = [
|
||||
sample.to(dtype=dtype) for sample in down_block_additional_residuals
|
||||
]
|
||||
pred_kwargs['down_block_additional_residuals'] = down_block_additional_residuals
|
||||
with self.timer('encode_adapter'):
|
||||
down_block_additional_residuals = self.adapter(adapter_images)
|
||||
down_block_additional_residuals = [
|
||||
sample.to(dtype=dtype) for sample in down_block_additional_residuals
|
||||
]
|
||||
pred_kwargs['down_block_additional_residuals'] = down_block_additional_residuals
|
||||
|
||||
if self.adapter and isinstance(self.adapter, IPAdapter):
|
||||
with torch.no_grad():
|
||||
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(adapter_images)
|
||||
conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds)
|
||||
with self.timer('encode_adapter'):
|
||||
with torch.no_grad():
|
||||
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(adapter_images)
|
||||
conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds)
|
||||
|
||||
|
||||
noise_pred = self.sd.predict_noise(
|
||||
latents=noisy_latents.to(self.device_torch, dtype=dtype),
|
||||
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
|
||||
timestep=timesteps,
|
||||
guidance_scale=1.0,
|
||||
**pred_kwargs
|
||||
)
|
||||
with self.timer('predict_unet'):
|
||||
noise_pred = self.sd.predict_noise(
|
||||
latents=noisy_latents.to(self.device_torch, dtype=dtype),
|
||||
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
|
||||
timestep=timesteps,
|
||||
guidance_scale=1.0,
|
||||
**pred_kwargs
|
||||
)
|
||||
|
||||
# if self.adapter:
|
||||
# # todo, diffusers does this on t2i training, is it better approach?
|
||||
@@ -194,43 +203,48 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
# dim=1,
|
||||
# )
|
||||
# else:
|
||||
noise = noise.to(self.device_torch, dtype=dtype).detach()
|
||||
if self.sd.prediction_type == 'v_prediction':
|
||||
# v-parameterization training
|
||||
target = self.sd.noise_scheduler.get_velocity(noisy_latents, noise, timesteps)
|
||||
else:
|
||||
target = noise
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
||||
# multiply by our mask
|
||||
loss = loss * mask_multiplier
|
||||
with self.timer('calculate_loss'):
|
||||
noise = noise.to(self.device_torch, dtype=dtype).detach()
|
||||
if self.sd.prediction_type == 'v_prediction':
|
||||
# v-parameterization training
|
||||
target = self.sd.noise_scheduler.get_velocity(noisy_latents, noise, timesteps)
|
||||
else:
|
||||
target = noise
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
||||
# multiply by our mask
|
||||
loss = loss * mask_multiplier
|
||||
|
||||
loss = loss.mean([1, 2, 3])
|
||||
loss = loss.mean([1, 2, 3])
|
||||
|
||||
if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
|
||||
# add min_snr_gamma
|
||||
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
|
||||
if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
|
||||
# add min_snr_gamma
|
||||
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
|
||||
|
||||
loss = loss.mean()
|
||||
loss = loss.mean()
|
||||
# check if nan
|
||||
if torch.isnan(loss):
|
||||
raise ValueError("loss is nan")
|
||||
|
||||
# IMPORTANT if gradient checkpointing do not leave with network when doing backward
|
||||
# it will destroy the gradients. This is because the network is a context manager
|
||||
# and will change the multipliers back to 0.0 when exiting. They will be
|
||||
# 0.0 for the backward pass and the gradients will be 0.0
|
||||
# I spent weeks on fighting this. DON'T DO IT
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm)
|
||||
# flush()
|
||||
with self.timer('backward'):
|
||||
# IMPORTANT if gradient checkpointing do not leave with network when doing backward
|
||||
# it will destroy the gradients. This is because the network is a context manager
|
||||
# and will change the multipliers back to 0.0 when exiting. They will be
|
||||
# 0.0 for the backward pass and the gradients will be 0.0
|
||||
# I spent weeks on fighting this. DON'T DO IT
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm)
|
||||
# flush()
|
||||
|
||||
# apply gradients
|
||||
self.optimizer.step()
|
||||
self.lr_scheduler.step()
|
||||
with self.timer('optimizer_step'):
|
||||
# apply gradients
|
||||
self.optimizer.step()
|
||||
with self.timer('scheduler_step'):
|
||||
self.lr_scheduler.step()
|
||||
|
||||
if self.embedding is not None:
|
||||
# Let's make sure we don't update any embedding weights besides the newly added token
|
||||
self.embedding.restore_embeddings()
|
||||
with self.timer('restore_embeddings'):
|
||||
# Let's make sure we don't update any embedding weights besides the newly added token
|
||||
self.embedding.restore_embeddings()
|
||||
|
||||
loss_dict = OrderedDict(
|
||||
{'loss': loss.item()}
|
||||
|
||||
@@ -2,6 +2,8 @@ import copy
|
||||
import json
|
||||
from collections import OrderedDict
|
||||
|
||||
from toolkit.timer import Timer
|
||||
|
||||
|
||||
class BaseProcess(object):
|
||||
|
||||
@@ -18,6 +20,9 @@ class BaseProcess(object):
|
||||
self.raw_process_config = config
|
||||
self.name = self.get_conf('name', self.job.name)
|
||||
self.meta = copy.deepcopy(self.job.meta)
|
||||
self.timer: Timer = Timer(f'{self.name} Timer')
|
||||
self.performance_log_every = self.get_conf('performance_log_every', 0)
|
||||
|
||||
print(json.dumps(self.config, indent=4))
|
||||
|
||||
def get_conf(self, key, default=None, required=False, as_type=None):
|
||||
|
||||
@@ -144,6 +144,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
self.named_lora = False
|
||||
if self.embed_config is not None or is_training_adapter:
|
||||
self.named_lora = True
|
||||
|
||||
def post_process_generate_image_config_list(self, generate_image_config_list: List[GenerateImageConfig]):
|
||||
# override in subclass
|
||||
return generate_image_config_list
|
||||
@@ -442,98 +443,102 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
|
||||
def process_general_training_batch(self, batch: 'DataLoaderBatchDTO'):
|
||||
with torch.no_grad():
|
||||
prompts = batch.get_caption_list()
|
||||
is_reg_list = batch.get_is_reg_list()
|
||||
with self.timer('prepare_prompt'):
|
||||
prompts = batch.get_caption_list()
|
||||
is_reg_list = batch.get_is_reg_list()
|
||||
|
||||
conditioned_prompts = []
|
||||
conditioned_prompts = []
|
||||
|
||||
for prompt, is_reg in zip(prompts, is_reg_list):
|
||||
for prompt, is_reg in zip(prompts, is_reg_list):
|
||||
|
||||
# make sure the embedding is in the prompts
|
||||
if self.embedding is not None:
|
||||
prompt = self.embedding.inject_embedding_to_prompt(
|
||||
prompt,
|
||||
expand_token=True,
|
||||
add_if_not_present=not is_reg,
|
||||
)
|
||||
# make sure the embedding is in the prompts
|
||||
if self.embedding is not None:
|
||||
prompt = self.embedding.inject_embedding_to_prompt(
|
||||
prompt,
|
||||
expand_token=True,
|
||||
add_if_not_present=not is_reg,
|
||||
)
|
||||
|
||||
# make sure trigger is in the prompts if not a regularization run
|
||||
if self.trigger_word is not None:
|
||||
prompt = self.sd.inject_trigger_into_prompt(
|
||||
prompt,
|
||||
trigger=self.trigger_word,
|
||||
add_if_not_present=not is_reg,
|
||||
)
|
||||
conditioned_prompts.append(prompt)
|
||||
# make sure trigger is in the prompts if not a regularization run
|
||||
if self.trigger_word is not None:
|
||||
prompt = self.sd.inject_trigger_into_prompt(
|
||||
prompt,
|
||||
trigger=self.trigger_word,
|
||||
add_if_not_present=not is_reg,
|
||||
)
|
||||
conditioned_prompts.append(prompt)
|
||||
|
||||
dtype = get_torch_dtype(self.train_config.dtype)
|
||||
imgs = None
|
||||
if batch.tensor is not None:
|
||||
imgs = batch.tensor
|
||||
imgs = imgs.to(self.device_torch, dtype=dtype)
|
||||
if batch.latents is not None:
|
||||
latents = batch.latents.to(self.device_torch, dtype=dtype)
|
||||
batch.latents = latents
|
||||
else:
|
||||
latents = self.sd.encode_images(imgs)
|
||||
batch.latents = latents
|
||||
flush()
|
||||
with self.timer('prepare_latents'):
|
||||
dtype = get_torch_dtype(self.train_config.dtype)
|
||||
imgs = None
|
||||
if batch.tensor is not None:
|
||||
imgs = batch.tensor
|
||||
imgs = imgs.to(self.device_torch, dtype=dtype)
|
||||
if batch.latents is not None:
|
||||
latents = batch.latents.to(self.device_torch, dtype=dtype)
|
||||
batch.latents = latents
|
||||
else:
|
||||
latents = self.sd.encode_images(imgs)
|
||||
batch.latents = latents
|
||||
# flush() # todo check performance removing this
|
||||
|
||||
batch_size = latents.shape[0]
|
||||
|
||||
self.sd.noise_scheduler.set_timesteps(
|
||||
1000, device=self.device_torch
|
||||
)
|
||||
with self.timer('prepare_noise'):
|
||||
|
||||
# if self.train_config.timestep_sampling == 'style' or self.train_config.timestep_sampling == 'content':
|
||||
if self.train_config.content_or_style in ['style', 'content']:
|
||||
# this is from diffusers training code
|
||||
# Cubic sampling for favoring later or earlier timesteps
|
||||
# For more details about why cubic sampling is used for content / structure,
|
||||
# refer to section 3.4 of https://arxiv.org/abs/2302.08453
|
||||
|
||||
# for content / structure, it is best to favor earlier timesteps
|
||||
# for style, it is best to favor later timesteps
|
||||
|
||||
timesteps = torch.rand((batch_size,), device=latents.device)
|
||||
|
||||
if self.train_config.content_or_style == 'style':
|
||||
timesteps = timesteps ** 3 * self.sd.noise_scheduler.config['num_train_timesteps']
|
||||
elif self.train_config.content_or_style == 'content':
|
||||
timesteps = (1 - timesteps ** 3) * self.sd.noise_scheduler.config['num_train_timesteps']
|
||||
|
||||
timesteps = value_map(
|
||||
timesteps,
|
||||
0,
|
||||
self.sd.noise_scheduler.config['num_train_timesteps'] - 1,
|
||||
self.train_config.min_denoising_steps,
|
||||
self.train_config.max_denoising_steps
|
||||
)
|
||||
timesteps = timesteps.long().clamp(
|
||||
self.train_config.min_denoising_steps,
|
||||
self.train_config.max_denoising_steps - 1
|
||||
self.sd.noise_scheduler.set_timesteps(
|
||||
1000, device=self.device_torch
|
||||
)
|
||||
|
||||
elif self.train_config.content_or_style == 'balanced':
|
||||
timesteps = torch.randint(
|
||||
self.train_config.min_denoising_steps,
|
||||
self.train_config.max_denoising_steps,
|
||||
(batch_size,),
|
||||
device=self.device_torch
|
||||
)
|
||||
timesteps = timesteps.long()
|
||||
else:
|
||||
raise ValueError(f"Unknown content_or_style {self.train_config.content_or_style}")
|
||||
# if self.train_config.timestep_sampling == 'style' or self.train_config.timestep_sampling == 'content':
|
||||
if self.train_config.content_or_style in ['style', 'content']:
|
||||
# this is from diffusers training code
|
||||
# Cubic sampling for favoring later or earlier timesteps
|
||||
# For more details about why cubic sampling is used for content / structure,
|
||||
# refer to section 3.4 of https://arxiv.org/abs/2302.08453
|
||||
|
||||
# get noise
|
||||
noise = self.sd.get_latent_noise(
|
||||
height=latents.shape[2],
|
||||
width=latents.shape[3],
|
||||
batch_size=batch_size,
|
||||
noise_offset=self.train_config.noise_offset
|
||||
).to(self.device_torch, dtype=dtype)
|
||||
# for content / structure, it is best to favor earlier timesteps
|
||||
# for style, it is best to favor later timesteps
|
||||
|
||||
noisy_latents = self.sd.noise_scheduler.add_noise(latents, noise, timesteps)
|
||||
timesteps = torch.rand((batch_size,), device=latents.device)
|
||||
|
||||
if self.train_config.content_or_style == 'style':
|
||||
timesteps = timesteps ** 3 * self.sd.noise_scheduler.config['num_train_timesteps']
|
||||
elif self.train_config.content_or_style == 'content':
|
||||
timesteps = (1 - timesteps ** 3) * self.sd.noise_scheduler.config['num_train_timesteps']
|
||||
|
||||
timesteps = value_map(
|
||||
timesteps,
|
||||
0,
|
||||
self.sd.noise_scheduler.config['num_train_timesteps'] - 1,
|
||||
self.train_config.min_denoising_steps,
|
||||
self.train_config.max_denoising_steps
|
||||
)
|
||||
timesteps = timesteps.long().clamp(
|
||||
self.train_config.min_denoising_steps,
|
||||
self.train_config.max_denoising_steps - 1
|
||||
)
|
||||
|
||||
elif self.train_config.content_or_style == 'balanced':
|
||||
timesteps = torch.randint(
|
||||
self.train_config.min_denoising_steps,
|
||||
self.train_config.max_denoising_steps,
|
||||
(batch_size,),
|
||||
device=self.device_torch
|
||||
)
|
||||
timesteps = timesteps.long()
|
||||
else:
|
||||
raise ValueError(f"Unknown content_or_style {self.train_config.content_or_style}")
|
||||
|
||||
# get noise
|
||||
noise = self.sd.get_latent_noise(
|
||||
height=latents.shape[2],
|
||||
width=latents.shape[3],
|
||||
batch_size=batch_size,
|
||||
noise_offset=self.train_config.noise_offset
|
||||
).to(self.device_torch, dtype=dtype)
|
||||
|
||||
noisy_latents = self.sd.noise_scheduler.add_noise(latents, noise, timesteps)
|
||||
|
||||
# remove grads for these
|
||||
noisy_latents.requires_grad = False
|
||||
@@ -933,24 +938,31 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
# don't do a reg step on sample or save steps as we dont want to normalize on those
|
||||
if step % 2 == 0 and dataloader_reg is not None and not is_save_step and not is_sample_step:
|
||||
try:
|
||||
batch = next(dataloader_iterator_reg)
|
||||
with self.timer('get_batch:reg'):
|
||||
batch = next(dataloader_iterator_reg)
|
||||
except StopIteration:
|
||||
# hit the end of an epoch, reset
|
||||
self.progress_bar.pause()
|
||||
dataloader_iterator_reg = iter(dataloader_reg)
|
||||
trigger_dataloader_setup_epoch(dataloader_reg)
|
||||
batch = next(dataloader_iterator_reg)
|
||||
with self.timer('reset_batch:reg'):
|
||||
# hit the end of an epoch, reset
|
||||
self.progress_bar.pause()
|
||||
dataloader_iterator_reg = iter(dataloader_reg)
|
||||
trigger_dataloader_setup_epoch(dataloader_reg)
|
||||
|
||||
with self.timer('get_batch:reg'):
|
||||
batch = next(dataloader_iterator_reg)
|
||||
self.progress_bar.unpause()
|
||||
is_reg_step = True
|
||||
elif dataloader is not None:
|
||||
try:
|
||||
batch = next(dataloader_iterator)
|
||||
with self.timer('get_batch'):
|
||||
batch = next(dataloader_iterator)
|
||||
except StopIteration:
|
||||
# hit the end of an epoch, reset
|
||||
self.progress_bar.pause()
|
||||
dataloader_iterator = iter(dataloader)
|
||||
trigger_dataloader_setup_epoch(dataloader)
|
||||
batch = next(dataloader_iterator)
|
||||
with self.timer('reset_batch'):
|
||||
# hit the end of an epoch, reset
|
||||
self.progress_bar.pause()
|
||||
dataloader_iterator = iter(dataloader)
|
||||
trigger_dataloader_setup_epoch(dataloader)
|
||||
with self.timer('get_batch'):
|
||||
batch = next(dataloader_iterator)
|
||||
self.progress_bar.unpause()
|
||||
else:
|
||||
batch = None
|
||||
@@ -960,7 +972,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
self.network.is_normalizing = True
|
||||
# flush()
|
||||
### HOOK ###
|
||||
self.timer.start('train_loop')
|
||||
loss_dict = self.hook_train_loop(batch)
|
||||
self.timer.stop('train_loop')
|
||||
# flush()
|
||||
# setup the networks to gradient checkpointing and everything works
|
||||
|
||||
@@ -998,11 +1012,19 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
|
||||
if self.logging_config.log_every and self.step_num % self.logging_config.log_every == 0:
|
||||
self.progress_bar.pause()
|
||||
# log to tensorboard
|
||||
if self.writer is not None:
|
||||
for key, value in loss_dict.items():
|
||||
self.writer.add_scalar(f"{key}", value, self.step_num)
|
||||
self.writer.add_scalar(f"lr", learning_rate, self.step_num)
|
||||
with self.timer('log_to_tensorboard'):
|
||||
# log to tensorboard
|
||||
if self.writer is not None:
|
||||
for key, value in loss_dict.items():
|
||||
self.writer.add_scalar(f"{key}", value, self.step_num)
|
||||
self.writer.add_scalar(f"lr", learning_rate, self.step_num)
|
||||
self.progress_bar.unpause()
|
||||
|
||||
if self.performance_log_every > 0 and self.step_num % self.performance_log_every == 0:
|
||||
self.progress_bar.pause()
|
||||
# print the timers and clear them
|
||||
self.timer.print()
|
||||
self.timer.reset()
|
||||
self.progress_bar.unpause()
|
||||
|
||||
# sets progress bar to match out step
|
||||
@@ -1012,11 +1034,13 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
|
||||
# apply network normalizer if we are using it, not on regularization steps
|
||||
if self.network is not None and self.network.is_normalizing and not is_reg_step:
|
||||
self.network.apply_stored_normalizer()
|
||||
with self.timer('apply_normalizer'):
|
||||
self.network.apply_stored_normalizer()
|
||||
|
||||
# if the batch is a DataLoaderBatchDTO, then we need to clean it up
|
||||
if isinstance(batch, DataLoaderBatchDTO):
|
||||
batch.cleanup()
|
||||
with self.timer('batch_cleanup'):
|
||||
batch.cleanup()
|
||||
|
||||
# flush every 10 steps
|
||||
# if self.step_num % 10 == 0:
|
||||
|
||||
@@ -448,7 +448,7 @@ class AiToolkitDataset(LatentCachingMixin, BucketsMixin, CaptionMixin, Dataset):
|
||||
return len(self.file_list)
|
||||
|
||||
def _get_single_item(self, index) -> 'FileItemDTO':
|
||||
file_item = self.file_list[index]
|
||||
file_item = copy.deepcopy(self.file_list[index])
|
||||
file_item.load_and_process_image(self.transform)
|
||||
file_item.load_caption(self.caption_dict)
|
||||
return file_item
|
||||
@@ -529,14 +529,14 @@ def get_dataloader_from_datasets(
|
||||
drop_last=False,
|
||||
shuffle=True,
|
||||
collate_fn=dto_collation, # Use the custom collate function
|
||||
num_workers=0
|
||||
num_workers=4
|
||||
)
|
||||
else:
|
||||
data_loader = DataLoader(
|
||||
concatenated_dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
num_workers=0,
|
||||
num_workers=4,
|
||||
collate_fn=dto_collation
|
||||
)
|
||||
return data_loader
|
||||
|
||||
@@ -670,7 +670,8 @@ class LatentCachingFileItemDTOMixin:
|
||||
# load it from disk
|
||||
state_dict = load_file(
|
||||
self.get_latent_path(),
|
||||
device=device if device is not None else self.latent_load_device
|
||||
# device=device if device is not None else self.latent_load_device
|
||||
device='cpu'
|
||||
)
|
||||
self._encoded_latent = state_dict['latent']
|
||||
return self._encoded_latent
|
||||
|
||||
@@ -53,7 +53,7 @@ class ToolkitModuleMixin:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.network_ref: weakref.ref = weakref.ref(network)
|
||||
self.is_checkpointing = False
|
||||
self.is_normalizing = False
|
||||
# self.is_normalizing = False
|
||||
self.normalize_scaler = 1.0
|
||||
self._multiplier: Union[float, list, torch.Tensor] = None
|
||||
|
||||
@@ -118,7 +118,7 @@ class ToolkitModuleMixin:
|
||||
multiplier = multiplier.repeat_interleave(num_interleaves)
|
||||
# multiplier = 1.0
|
||||
|
||||
if self.is_normalizing:
|
||||
if self.network_ref().is_normalizing:
|
||||
with torch.no_grad():
|
||||
|
||||
# do this calculation without set multiplier and instead use same polarity, but with 1.0 multiplier
|
||||
@@ -390,8 +390,8 @@ class ToolkitNetworkMixin:
|
||||
@is_normalizing.setter
|
||||
def is_normalizing(self: Network, value: bool):
|
||||
self._is_normalizing = value
|
||||
for module in self.get_all_modules():
|
||||
module.is_normalizing = self._is_normalizing
|
||||
# for module in self.get_all_modules():
|
||||
# module.is_normalizing = self._is_normalizing
|
||||
|
||||
def apply_stored_normalizer(self: Network, target_normalize_scaler: float = 1.0):
|
||||
for module in self.get_all_modules():
|
||||
|
||||
65
toolkit/timer.py
Normal file
65
toolkit/timer.py
Normal file
@@ -0,0 +1,65 @@
|
||||
import time
|
||||
from collections import OrderedDict, deque
|
||||
|
||||
|
||||
class Timer:
    """Collect wall-time samples for named sections of code.

    Each timer name owns a bounded buffer holding its most recent durations
    (in seconds). Sections can be timed either explicitly::

        timer.start('step')
        ...
        timer.stop('step')

    or with a context manager, which may be nested::

        with timer('outer'):
            with timer('inner'):
                ...

    Attributes:
        name: Label printed as the header of the report.
        max_buffer: Maximum number of samples kept per timer name
            (``None`` keeps every sample).
        timers: Ordered mapping of timer name -> deque of durations.
        active_timers: Mapping of timer name -> start timestamp for timers
            that were started but not yet stopped or cancelled.
        current_timer: Name of the innermost timer opened through the
            context-manager form, or ``None`` when no such timer is open.
    """

    def __init__(self, name='Timer', max_buffer=10):
        self.name = name
        self.max_buffer = max_buffer
        self.timers = OrderedDict()
        self.active_timers = {}
        self.current_timer = None  # innermost context-managed timer name
        # Names opened via ``timer(name)``, innermost last. A stack (rather
        # than the single ``current_timer`` slot) lets nested ``with`` blocks
        # each stop the timer they started.
        self._context_stack = []

    def start(self, timer_name: str) -> None:
        """Start (or restart) the timer called ``timer_name``."""
        if timer_name not in self.timers:
            # maxlen makes the deque discard its oldest sample automatically.
            self.timers[timer_name] = deque(maxlen=self.max_buffer)
        # perf_counter is monotonic and high resolution; unlike time.time()
        # it cannot jump when the system clock is adjusted.
        self.active_timers[timer_name] = time.perf_counter()

    def cancel(self, timer_name: str) -> None:
        """Discard an active timer without recording a sample.

        Cancelling a timer that is not running is a no-op.
        """
        self.active_timers.pop(timer_name, None)

    def stop(self, timer_name: str) -> None:
        """Stop ``timer_name`` and record the elapsed time.

        Raises:
            ValueError: If the timer is not currently running.
        """
        if timer_name not in self.active_timers:
            raise ValueError(f"Timer '{timer_name}' was not started!")

        started_at = self.active_timers.pop(timer_name)
        self.timers[timer_name].append(time.perf_counter() - started_at)

    def print(self):
        """Print the average duration of every timer, slowest total first.

        Timers without any completed sample (started but never stopped, or
        cancelled) are left out instead of causing a division by zero.
        """
        print(f"\nTimer '{self.name}':")
        recorded = [(name, timings) for name, timings in self.timers.items() if timings]
        # sort by longest at top
        for timer_name, timings in sorted(recorded, key=lambda item: sum(item[1]), reverse=True):
            avg_time = sum(timings) / len(timings)
            print(f" - {avg_time:.4f}s avg - {timer_name}, num = {len(timings)}")

        print('')

    def reset(self):
        """Drop all recorded samples and forget every running timer."""
        self.timers.clear()
        self.active_timers.clear()

    def __call__(self, timer_name: str):
        """Start ``timer_name`` and return ``self`` for use in a ``with`` block."""
        self._context_stack.append(timer_name)
        self.current_timer = timer_name
        self.start(timer_name)
        return self

    def __enter__(self):
        # The timer was already started by __call__; hand back the Timer so
        # ``with timer('x') as t`` is usable.
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Close the timer opened by the matching __call__ (innermost first).
        if self._context_stack:
            timer_name = self._context_stack.pop()
        else:
            timer_name = self.current_timer
        self.current_timer = self._context_stack[-1] if self._context_stack else None

        if exc_type is None:
            # No exceptions, stop the timer normally
            self.stop(timer_name)
        else:
            # There was an exception, cancel the timer so a partial
            # measurement does not pollute the averages
            self.cancel(timer_name)
        # Returning None lets any exception propagate to the caller.
|
||||
Reference in New Issue
Block a user