Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-05-01 11:41:35 +00:00
Massive speed increases and ram optimizations
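The change wraps each stage of hook_train_loop in a named timer so the per-step cost of preprocessing, text encoding, adapter encoding, the UNet forward pass, the backward pass, and the optimizer/scheduler steps can be measured separately. The toolkit's own Timer implementation is not part of this diff; the following is only a minimal sketch of an object that supports the three call styles used below (self.timer.start(name), self.timer.stop(name), and with self.timer(name):), with all names and structure assumed for illustration.

    import time
    from collections import defaultdict
    from contextlib import contextmanager


    class Timer:
        """Accumulates wall-clock time per named section of a training step."""

        def __init__(self):
            self.totals = defaultdict(float)  # section name -> accumulated seconds
            self.counts = defaultdict(int)    # section name -> number of timed runs
            self._starts = {}

        def start(self, name):
            self._starts[name] = time.perf_counter()

        def stop(self, name):
            self.totals[name] += time.perf_counter() - self._starts.pop(name)
            self.counts[name] += 1

        @contextmanager
        def __call__(self, name):
            # allows usage as: with self.timer('predict_unet'): ...
            self.start(name)
            try:
                yield
            finally:
                self.stop(name)

With an object of this shape assigned to self.timer, the average time per section is totals[name] / counts[name], which is what makes it possible to see where a training step actually spends its time.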
@@ -94,85 +94,94 @@ class SDTrainer(BaseSDTrainProcess):

     def hook_train_loop(self, batch):

+        self.timer.start('preprocess_batch')
         dtype = get_torch_dtype(self.train_config.dtype)
         noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)
         network_weight_list = batch.get_network_weight_list()
+        self.timer.stop('preprocess_batch')

         with torch.no_grad():
             adapter_images = None
             sigmas = None
             if self.adapter:
-                # todo move this to data loader
-                if batch.control_tensor is not None:
-                    adapter_images = batch.control_tensor.to(self.device_torch, dtype=dtype).detach()
-                else:
-                    adapter_images = self.get_adapter_images(batch)
-                # not 100% sure what this does. But they do it here
-                # https://github.com/huggingface/diffusers/blob/38a664a3d61e27ab18cd698231422b3c38d6eebf/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L1170
-                # sigmas = self.get_sigmas(timesteps, len(noisy_latents.shape), noisy_latents.dtype)
-                # noisy_latents = noisy_latents / ((sigmas ** 2 + 1) ** 0.5)
+                with self.timer('get_adapter_images'):
+                    # todo move this to data loader
+                    if batch.control_tensor is not None:
+                        adapter_images = batch.control_tensor.to(self.device_torch, dtype=dtype).detach()
+                    else:
+                        adapter_images = self.get_adapter_images(batch)
+                    # not 100% sure what this does. But they do it here
+                    # https://github.com/huggingface/diffusers/blob/38a664a3d61e27ab18cd698231422b3c38d6eebf/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L1170
+                    # sigmas = self.get_sigmas(timesteps, len(noisy_latents.shape), noisy_latents.dtype)
+                    # noisy_latents = noisy_latents / ((sigmas ** 2 + 1) ** 0.5)

             mask_multiplier = 1.0
             if batch.mask_tensor is not None:
-                # upsampling no supported for bfloat16
-                mask_multiplier = batch.mask_tensor.to(self.device_torch, dtype=torch.float16).detach()
-                # scale down to the size of the latents, mask multiplier shape(bs, 1, width, height), noisy_latents shape(bs, channels, width, height)
-                mask_multiplier = torch.nn.functional.interpolate(
-                    mask_multiplier, size=(noisy_latents.shape[2], noisy_latents.shape[3])
-                )
-                # expand to match latents
-                mask_multiplier = mask_multiplier.expand(-1, noisy_latents.shape[1], -1, -1)
-                mask_multiplier = mask_multiplier.to(self.device_torch, dtype=dtype).detach()
+                with self.timer('get_mask_multiplier'):
+                    # upsampling no supported for bfloat16
+                    mask_multiplier = batch.mask_tensor.to(self.device_torch, dtype=torch.float16).detach()
+                    # scale down to the size of the latents, mask multiplier shape(bs, 1, width, height), noisy_latents shape(bs, channels, width, height)
+                    mask_multiplier = torch.nn.functional.interpolate(
+                        mask_multiplier, size=(noisy_latents.shape[2], noisy_latents.shape[3])
+                    )
+                    # expand to match latents
+                    mask_multiplier = mask_multiplier.expand(-1, noisy_latents.shape[1], -1, -1)
+                    mask_multiplier = mask_multiplier.to(self.device_torch, dtype=dtype).detach()

         # flush()
-        self.optimizer.zero_grad()
+        with self.timer('grad_setup'):
+            self.optimizer.zero_grad()

-        # text encoding
-        grad_on_text_encoder = False
-        if self.train_config.train_text_encoder:
-            grad_on_text_encoder = True
+            # text encoding
+            grad_on_text_encoder = False
+            if self.train_config.train_text_encoder:
+                grad_on_text_encoder = True

-        if self.embedding:
-            grad_on_text_encoder = True
+            if self.embedding:
+                grad_on_text_encoder = True

-        # have a blank network so we can wrap it in a context and set multipliers without checking every time
-        if self.network is not None:
-            network = self.network
-        else:
-            network = BlankNetwork()
+            # have a blank network so we can wrap it in a context and set multipliers without checking every time
+            if self.network is not None:
+                network = self.network
+            else:
+                network = BlankNetwork()

-        # set the weights
-        network.multiplier = network_weight_list
+            # set the weights
+            network.multiplier = network_weight_list

         # activate network if it exits
         with network:
-            with torch.set_grad_enabled(grad_on_text_encoder):
-                conditional_embeds = self.sd.encode_prompt(conditioned_prompts).to(self.device_torch, dtype=dtype)
-                if not grad_on_text_encoder:
-                    # detach the embeddings
-                    conditional_embeds = conditional_embeds.detach()
+            with self.timer('encode_prompt'):
+                with torch.set_grad_enabled(grad_on_text_encoder):
+                    conditional_embeds = self.sd.encode_prompt(conditioned_prompts).to(self.device_torch, dtype=dtype)
+                    if not grad_on_text_encoder:
+                        # detach the embeddings
+                        conditional_embeds = conditional_embeds.detach()
             # flush()
             pred_kwargs = {}
             if self.adapter and isinstance(self.adapter, T2IAdapter):
-                down_block_additional_residuals = self.adapter(adapter_images)
-                down_block_additional_residuals = [
-                    sample.to(dtype=dtype) for sample in down_block_additional_residuals
-                ]
-                pred_kwargs['down_block_additional_residuals'] = down_block_additional_residuals
+                with self.timer('encode_adapter'):
+                    down_block_additional_residuals = self.adapter(adapter_images)
+                    down_block_additional_residuals = [
+                        sample.to(dtype=dtype) for sample in down_block_additional_residuals
+                    ]
+                    pred_kwargs['down_block_additional_residuals'] = down_block_additional_residuals

             if self.adapter and isinstance(self.adapter, IPAdapter):
-                with torch.no_grad():
-                    conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(adapter_images)
-                conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds)
+                with self.timer('encode_adapter'):
+                    with torch.no_grad():
+                        conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(adapter_images)
+                    conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds)


-            noise_pred = self.sd.predict_noise(
-                latents=noisy_latents.to(self.device_torch, dtype=dtype),
-                conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
-                timestep=timesteps,
-                guidance_scale=1.0,
-                **pred_kwargs
-            )
+            with self.timer('predict_unet'):
+                noise_pred = self.sd.predict_noise(
+                    latents=noisy_latents.to(self.device_torch, dtype=dtype),
+                    conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
+                    timestep=timesteps,
+                    guidance_scale=1.0,
+                    **pred_kwargs
+                )

             # if self.adapter:
             #     # todo, diffusers does this on t2i training, is it better approach?
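The mask handling in the hunk above casts the mask to float16 before interpolate (the comment notes upsampling was not supported for bfloat16), resizes it from pixel space to the latent grid, and broadcasts it across latent channels so it can scale the per-element loss. A standalone sketch with assumed shapes, using float32 so it runs on any device:

    import torch

    # assumed sizes, for illustration only
    bs, latent_channels, lat_h, lat_w = 2, 4, 64, 64
    mask = torch.rand(bs, 1, 512, 512)  # pixel-space mask, shape (bs, 1, H, W)

    # downscale to the latent resolution (the training code casts to float16 on
    # GPU at this point because interpolation was not available for bfloat16)
    mask = torch.nn.functional.interpolate(mask.float(), size=(lat_h, lat_w))

    # broadcast across latent channels so it matches the per-element loss tensor
    mask = mask.expand(-1, latent_channels, -1, -1)
    print(mask.shape)  # torch.Size([2, 4, 64, 64])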
@@ -194,43 +203,48 @@ class SDTrainer(BaseSDTrainProcess):
             #     dim=1,
             # )
             # else:
-            noise = noise.to(self.device_torch, dtype=dtype).detach()
-            if self.sd.prediction_type == 'v_prediction':
-                # v-parameterization training
-                target = self.sd.noise_scheduler.get_velocity(noisy_latents, noise, timesteps)
-            else:
-                target = noise
-            loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
-            # multiply by our mask
-            loss = loss * mask_multiplier
+            with self.timer('calculate_loss'):
+                noise = noise.to(self.device_torch, dtype=dtype).detach()
+                if self.sd.prediction_type == 'v_prediction':
+                    # v-parameterization training
+                    target = self.sd.noise_scheduler.get_velocity(noisy_latents, noise, timesteps)
+                else:
+                    target = noise
+                loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
+                # multiply by our mask
+                loss = loss * mask_multiplier

-            loss = loss.mean([1, 2, 3])
+                loss = loss.mean([1, 2, 3])

-            if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
-                # add min_snr_gamma
-                loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
+                if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
+                    # add min_snr_gamma
+                    loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)

-            loss = loss.mean()
+                loss = loss.mean()
             # check if nan
             if torch.isnan(loss):
                 raise ValueError("loss is nan")

-            # IMPORTANT if gradient checkpointing do not leave with network when doing backward
-            # it will destroy the gradients. This is because the network is a context manager
-            # and will change the multipliers back to 0.0 when exiting. They will be
-            # 0.0 for the backward pass and the gradients will be 0.0
-            # I spent weeks on fighting this. DON'T DO IT
-            loss.backward()
-            torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm)
-            # flush()
+            with self.timer('backward'):
+                # IMPORTANT if gradient checkpointing do not leave with network when doing backward
+                # it will destroy the gradients. This is because the network is a context manager
+                # and will change the multipliers back to 0.0 when exiting. They will be
+                # 0.0 for the backward pass and the gradients will be 0.0
+                # I spent weeks on fighting this. DON'T DO IT
+                loss.backward()
+                torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm)
+                # flush()

-            # apply gradients
-            self.optimizer.step()
-            self.lr_scheduler.step()
+            with self.timer('optimizer_step'):
+                # apply gradients
+                self.optimizer.step()
+            with self.timer('scheduler_step'):
+                self.lr_scheduler.step()

             if self.embedding is not None:
-                # Let's make sure we don't update any embedding weights besides the newly added token
-                self.embedding.restore_embeddings()
+                with self.timer('restore_embeddings'):
+                    # Let's make sure we don't update any embedding weights besides the newly added token
+                    self.embedding.restore_embeddings()

             loss_dict = OrderedDict(
                 {'loss': loss.item()}
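The calculate_loss block keeps the existing min-SNR weighting; apply_snr_weight itself is not shown in this diff. Below is a sketch of the standard Min-SNR-gamma weighting such a helper typically applies to a per-sample loss of shape (bs,), written for epsilon-prediction targets; all names here are assumptions, not the toolkit's API.

    import torch

    def min_snr_weights(timesteps, alphas_cumprod, gamma):
        """Standard Min-SNR-gamma weights: min(SNR_t, gamma) / SNR_t per sample."""
        abar = alphas_cumprod[timesteps]          # (bs,) cumulative alpha at each timestep
        snr = abar / (1.0 - abar)                 # signal-to-noise ratio SNR_t
        return torch.clamp(snr, max=gamma) / snr  # (bs,) weights for epsilon-prediction loss

    # usage sketch, matching the shapes in the hunk above: loss has already been
    # reduced to one value per sample with loss.mean([1, 2, 3])
    # loss = loss * min_snr_weights(timesteps, noise_scheduler.alphas_cumprod.to(loss.device), gamma)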
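The IMPORTANT comment in the backward block refers to the LoRA network being a context manager that zeroes its multipliers on exit; with gradient checkpointing, the forward pass is re-run during backward(), so leaving the `with network:` block first silently produces zero gradients. A toy illustration of that mechanism (not the toolkit's network class):

    class ToyNetwork:
        """Stand-in for a LoRA-style network whose effect is scaled by a multiplier."""

        def __init__(self):
            self.multiplier = 0.0

        def __enter__(self):
            self.multiplier = 1.0   # LoRA contribution is active inside the block
            return self

        def __exit__(self, *exc):
            self.multiplier = 0.0   # on exit the LoRA contribution is switched off


    network = ToyNetwork()
    with network:
        pass  # forward pass and loss computation happen here (multiplier == 1.0)

    # At this point network.multiplier is already 0.0. With gradient checkpointing,
    # loss.backward() re-runs the forward using the current multiplier, so calling
    # it out here recomputes activations with the LoRA scaled to zero and every
    # LoRA gradient comes out as zero, which is why backward() stays inside the block.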