Massive speed increases and RAM optimizations

This commit is contained in:
Jaret Burkett
2023-10-10 06:07:55 -06:00
parent f4c90bb589
commit 63ceffae24
7 changed files with 294 additions and 185 deletions

View File

@@ -94,85 +94,94 @@ class SDTrainer(BaseSDTrainProcess):
def hook_train_loop(self, batch):
self.timer.start('preprocess_batch')
dtype = get_torch_dtype(self.train_config.dtype)
noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)
network_weight_list = batch.get_network_weight_list()
self.timer.stop('preprocess_batch')
with torch.no_grad():
adapter_images = None
sigmas = None
if self.adapter:
# todo move this to data loader
if batch.control_tensor is not None:
adapter_images = batch.control_tensor.to(self.device_torch, dtype=dtype).detach()
else:
adapter_images = self.get_adapter_images(batch)
# not 100% sure what this does. But they do it here
# https://github.com/huggingface/diffusers/blob/38a664a3d61e27ab18cd698231422b3c38d6eebf/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L1170
# sigmas = self.get_sigmas(timesteps, len(noisy_latents.shape), noisy_latents.dtype)
# noisy_latents = noisy_latents / ((sigmas ** 2 + 1) ** 0.5)
with self.timer('get_adapter_images'):
# todo move this to data loader
if batch.control_tensor is not None:
adapter_images = batch.control_tensor.to(self.device_torch, dtype=dtype).detach()
else:
adapter_images = self.get_adapter_images(batch)
# not 100% sure what this does. But they do it here
# https://github.com/huggingface/diffusers/blob/38a664a3d61e27ab18cd698231422b3c38d6eebf/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L1170
# sigmas = self.get_sigmas(timesteps, len(noisy_latents.shape), noisy_latents.dtype)
# noisy_latents = noisy_latents / ((sigmas ** 2 + 1) ** 0.5)
mask_multiplier = 1.0
if batch.mask_tensor is not None:
# upsampling not supported for bfloat16
mask_multiplier = batch.mask_tensor.to(self.device_torch, dtype=torch.float16).detach()
# scale down to the size of the latents, mask multiplier shape(bs, 1, width, height), noisy_latents shape(bs, channels, width, height)
mask_multiplier = torch.nn.functional.interpolate(
mask_multiplier, size=(noisy_latents.shape[2], noisy_latents.shape[3])
)
# expand to match latents
mask_multiplier = mask_multiplier.expand(-1, noisy_latents.shape[1], -1, -1)
mask_multiplier = mask_multiplier.to(self.device_torch, dtype=dtype).detach()
with self.timer('get_mask_multiplier'):
# upsampling not supported for bfloat16
mask_multiplier = batch.mask_tensor.to(self.device_torch, dtype=torch.float16).detach()
# scale down to the size of the latents, mask multiplier shape(bs, 1, width, height), noisy_latents shape(bs, channels, width, height)
mask_multiplier = torch.nn.functional.interpolate(
mask_multiplier, size=(noisy_latents.shape[2], noisy_latents.shape[3])
)
# expand to match latents
mask_multiplier = mask_multiplier.expand(-1, noisy_latents.shape[1], -1, -1)
mask_multiplier = mask_multiplier.to(self.device_torch, dtype=dtype).detach()
# flush()
self.optimizer.zero_grad()
with self.timer('grad_setup'):
self.optimizer.zero_grad()
# text encoding
grad_on_text_encoder = False
if self.train_config.train_text_encoder:
grad_on_text_encoder = True
# text encoding
grad_on_text_encoder = False
if self.train_config.train_text_encoder:
grad_on_text_encoder = True
if self.embedding:
grad_on_text_encoder = True
if self.embedding:
grad_on_text_encoder = True
# have a blank network so we can wrap it in a context and set multipliers without checking every time
if self.network is not None:
network = self.network
else:
network = BlankNetwork()
# have a blank network so we can wrap it in a context and set multipliers without checking every time
if self.network is not None:
network = self.network
else:
network = BlankNetwork()
# set the weights
network.multiplier = network_weight_list
# set the weights
network.multiplier = network_weight_list
# activate network if it exits
with network:
with torch.set_grad_enabled(grad_on_text_encoder):
conditional_embeds = self.sd.encode_prompt(conditioned_prompts).to(self.device_torch, dtype=dtype)
if not grad_on_text_encoder:
# detach the embeddings
conditional_embeds = conditional_embeds.detach()
with self.timer('encode_prompt'):
with torch.set_grad_enabled(grad_on_text_encoder):
conditional_embeds = self.sd.encode_prompt(conditioned_prompts).to(self.device_torch, dtype=dtype)
if not grad_on_text_encoder:
# detach the embeddings
conditional_embeds = conditional_embeds.detach()
# flush()
pred_kwargs = {}
if self.adapter and isinstance(self.adapter, T2IAdapter):
down_block_additional_residuals = self.adapter(adapter_images)
down_block_additional_residuals = [
sample.to(dtype=dtype) for sample in down_block_additional_residuals
]
pred_kwargs['down_block_additional_residuals'] = down_block_additional_residuals
with self.timer('encode_adapter'):
down_block_additional_residuals = self.adapter(adapter_images)
down_block_additional_residuals = [
sample.to(dtype=dtype) for sample in down_block_additional_residuals
]
pred_kwargs['down_block_additional_residuals'] = down_block_additional_residuals
if self.adapter and isinstance(self.adapter, IPAdapter):
with torch.no_grad():
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(adapter_images)
conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds)
with self.timer('encode_adapter'):
with torch.no_grad():
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(adapter_images)
conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds)
noise_pred = self.sd.predict_noise(
latents=noisy_latents.to(self.device_torch, dtype=dtype),
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
timestep=timesteps,
guidance_scale=1.0,
**pred_kwargs
)
with self.timer('predict_unet'):
noise_pred = self.sd.predict_noise(
latents=noisy_latents.to(self.device_torch, dtype=dtype),
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
timestep=timesteps,
guidance_scale=1.0,
**pred_kwargs
)
# if self.adapter:
# # todo, diffusers does this on t2i training, is it better approach?
@@ -194,43 +203,48 @@ class SDTrainer(BaseSDTrainProcess):
# dim=1,
# )
# else:
noise = noise.to(self.device_torch, dtype=dtype).detach()
if self.sd.prediction_type == 'v_prediction':
# v-parameterization training
target = self.sd.noise_scheduler.get_velocity(noisy_latents, noise, timesteps)
else:
target = noise
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
# multiply by our mask
loss = loss * mask_multiplier
with self.timer('calculate_loss'):
noise = noise.to(self.device_torch, dtype=dtype).detach()
if self.sd.prediction_type == 'v_prediction':
# v-parameterization training
target = self.sd.noise_scheduler.get_velocity(noisy_latents, noise, timesteps)
else:
target = noise
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
# multiply by our mask
loss = loss * mask_multiplier
loss = loss.mean([1, 2, 3])
loss = loss.mean([1, 2, 3])
if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
# add min_snr_gamma
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
# add min_snr_gamma
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
loss = loss.mean()
loss = loss.mean()
# check if nan
if torch.isnan(loss):
raise ValueError("loss is nan")
# IMPORTANT if gradient checkpointing do not leave with network when doing backward
# it will destroy the gradients. This is because the network is a context manager
# and will change the multipliers back to 0.0 when exiting. They will be
# 0.0 for the backward pass and the gradients will be 0.0
# I spent weeks on fighting this. DON'T DO IT
loss.backward()
torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm)
# flush()
with self.timer('backward'):
# IMPORTANT if gradient checkpointing do not leave with network when doing backward
# it will destroy the gradients. This is because the network is a context manager
# and will change the multipliers back to 0.0 when exiting. They will be
# 0.0 for the backward pass and the gradients will be 0.0
# I spent weeks on fighting this. DON'T DO IT
loss.backward()
torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm)
# flush()
# apply gradients
self.optimizer.step()
self.lr_scheduler.step()
with self.timer('optimizer_step'):
# apply gradients
self.optimizer.step()
with self.timer('scheduler_step'):
self.lr_scheduler.step()
if self.embedding is not None:
# Let's make sure we don't update any embedding weights besides the newly added token
self.embedding.restore_embeddings()
with self.timer('restore_embeddings'):
# Let's make sure we don't update any embedding weights besides the newly added token
self.embedding.restore_embeddings()
loss_dict = OrderedDict(
{'loss': loss.item()}

View File

@@ -2,6 +2,8 @@ import copy
import json
from collections import OrderedDict
from toolkit.timer import Timer
class BaseProcess(object):
@@ -18,6 +20,9 @@ class BaseProcess(object):
self.raw_process_config = config
self.name = self.get_conf('name', self.job.name)
self.meta = copy.deepcopy(self.job.meta)
self.timer: Timer = Timer(f'{self.name} Timer')
self.performance_log_every = self.get_conf('performance_log_every', 0)
print(json.dumps(self.config, indent=4))
def get_conf(self, key, default=None, required=False, as_type=None):

View File

@@ -144,6 +144,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.named_lora = False
if self.embed_config is not None or is_training_adapter:
self.named_lora = True
def post_process_generate_image_config_list(self, generate_image_config_list: List[GenerateImageConfig]):
# override in subclass
return generate_image_config_list
@@ -442,98 +443,102 @@ class BaseSDTrainProcess(BaseTrainProcess):
def process_general_training_batch(self, batch: 'DataLoaderBatchDTO'):
with torch.no_grad():
prompts = batch.get_caption_list()
is_reg_list = batch.get_is_reg_list()
with self.timer('prepare_prompt'):
prompts = batch.get_caption_list()
is_reg_list = batch.get_is_reg_list()
conditioned_prompts = []
conditioned_prompts = []
for prompt, is_reg in zip(prompts, is_reg_list):
for prompt, is_reg in zip(prompts, is_reg_list):
# make sure the embedding is in the prompts
if self.embedding is not None:
prompt = self.embedding.inject_embedding_to_prompt(
prompt,
expand_token=True,
add_if_not_present=not is_reg,
)
# make sure the embedding is in the prompts
if self.embedding is not None:
prompt = self.embedding.inject_embedding_to_prompt(
prompt,
expand_token=True,
add_if_not_present=not is_reg,
)
# make sure trigger is in the prompts if not a regularization run
if self.trigger_word is not None:
prompt = self.sd.inject_trigger_into_prompt(
prompt,
trigger=self.trigger_word,
add_if_not_present=not is_reg,
)
conditioned_prompts.append(prompt)
# make sure trigger is in the prompts if not a regularization run
if self.trigger_word is not None:
prompt = self.sd.inject_trigger_into_prompt(
prompt,
trigger=self.trigger_word,
add_if_not_present=not is_reg,
)
conditioned_prompts.append(prompt)
dtype = get_torch_dtype(self.train_config.dtype)
imgs = None
if batch.tensor is not None:
imgs = batch.tensor
imgs = imgs.to(self.device_torch, dtype=dtype)
if batch.latents is not None:
latents = batch.latents.to(self.device_torch, dtype=dtype)
batch.latents = latents
else:
latents = self.sd.encode_images(imgs)
batch.latents = latents
flush()
with self.timer('prepare_latents'):
dtype = get_torch_dtype(self.train_config.dtype)
imgs = None
if batch.tensor is not None:
imgs = batch.tensor
imgs = imgs.to(self.device_torch, dtype=dtype)
if batch.latents is not None:
latents = batch.latents.to(self.device_torch, dtype=dtype)
batch.latents = latents
else:
latents = self.sd.encode_images(imgs)
batch.latents = latents
# flush() # todo check performance removing this
batch_size = latents.shape[0]
self.sd.noise_scheduler.set_timesteps(
1000, device=self.device_torch
)
with self.timer('prepare_noise'):
# if self.train_config.timestep_sampling == 'style' or self.train_config.timestep_sampling == 'content':
if self.train_config.content_or_style in ['style', 'content']:
# this is from diffusers training code
# Cubic sampling for favoring later or earlier timesteps
# For more details about why cubic sampling is used for content / structure,
# refer to section 3.4 of https://arxiv.org/abs/2302.08453
# for content / structure, it is best to favor earlier timesteps
# for style, it is best to favor later timesteps
timesteps = torch.rand((batch_size,), device=latents.device)
if self.train_config.content_or_style == 'style':
timesteps = timesteps ** 3 * self.sd.noise_scheduler.config['num_train_timesteps']
elif self.train_config.content_or_style == 'content':
timesteps = (1 - timesteps ** 3) * self.sd.noise_scheduler.config['num_train_timesteps']
timesteps = value_map(
timesteps,
0,
self.sd.noise_scheduler.config['num_train_timesteps'] - 1,
self.train_config.min_denoising_steps,
self.train_config.max_denoising_steps
)
timesteps = timesteps.long().clamp(
self.train_config.min_denoising_steps,
self.train_config.max_denoising_steps - 1
self.sd.noise_scheduler.set_timesteps(
1000, device=self.device_torch
)
elif self.train_config.content_or_style == 'balanced':
timesteps = torch.randint(
self.train_config.min_denoising_steps,
self.train_config.max_denoising_steps,
(batch_size,),
device=self.device_torch
)
timesteps = timesteps.long()
else:
raise ValueError(f"Unknown content_or_style {self.train_config.content_or_style}")
# if self.train_config.timestep_sampling == 'style' or self.train_config.timestep_sampling == 'content':
if self.train_config.content_or_style in ['style', 'content']:
# this is from diffusers training code
# Cubic sampling for favoring later or earlier timesteps
# For more details about why cubic sampling is used for content / structure,
# refer to section 3.4 of https://arxiv.org/abs/2302.08453
# get noise
noise = self.sd.get_latent_noise(
height=latents.shape[2],
width=latents.shape[3],
batch_size=batch_size,
noise_offset=self.train_config.noise_offset
).to(self.device_torch, dtype=dtype)
# for content / structure, it is best to favor earlier timesteps
# for style, it is best to favor later timesteps
noisy_latents = self.sd.noise_scheduler.add_noise(latents, noise, timesteps)
timesteps = torch.rand((batch_size,), device=latents.device)
if self.train_config.content_or_style == 'style':
timesteps = timesteps ** 3 * self.sd.noise_scheduler.config['num_train_timesteps']
elif self.train_config.content_or_style == 'content':
timesteps = (1 - timesteps ** 3) * self.sd.noise_scheduler.config['num_train_timesteps']
timesteps = value_map(
timesteps,
0,
self.sd.noise_scheduler.config['num_train_timesteps'] - 1,
self.train_config.min_denoising_steps,
self.train_config.max_denoising_steps
)
timesteps = timesteps.long().clamp(
self.train_config.min_denoising_steps,
self.train_config.max_denoising_steps - 1
)
elif self.train_config.content_or_style == 'balanced':
timesteps = torch.randint(
self.train_config.min_denoising_steps,
self.train_config.max_denoising_steps,
(batch_size,),
device=self.device_torch
)
timesteps = timesteps.long()
else:
raise ValueError(f"Unknown content_or_style {self.train_config.content_or_style}")
# get noise
noise = self.sd.get_latent_noise(
height=latents.shape[2],
width=latents.shape[3],
batch_size=batch_size,
noise_offset=self.train_config.noise_offset
).to(self.device_torch, dtype=dtype)
noisy_latents = self.sd.noise_scheduler.add_noise(latents, noise, timesteps)
# remove grads for these
noisy_latents.requires_grad = False
@@ -933,24 +938,31 @@ class BaseSDTrainProcess(BaseTrainProcess):
# don't do a reg step on sample or save steps as we dont want to normalize on those
if step % 2 == 0 and dataloader_reg is not None and not is_save_step and not is_sample_step:
try:
batch = next(dataloader_iterator_reg)
with self.timer('get_batch:reg'):
batch = next(dataloader_iterator_reg)
except StopIteration:
# hit the end of an epoch, reset
self.progress_bar.pause()
dataloader_iterator_reg = iter(dataloader_reg)
trigger_dataloader_setup_epoch(dataloader_reg)
batch = next(dataloader_iterator_reg)
with self.timer('reset_batch:reg'):
# hit the end of an epoch, reset
self.progress_bar.pause()
dataloader_iterator_reg = iter(dataloader_reg)
trigger_dataloader_setup_epoch(dataloader_reg)
with self.timer('get_batch:reg'):
batch = next(dataloader_iterator_reg)
self.progress_bar.unpause()
is_reg_step = True
elif dataloader is not None:
try:
batch = next(dataloader_iterator)
with self.timer('get_batch'):
batch = next(dataloader_iterator)
except StopIteration:
# hit the end of an epoch, reset
self.progress_bar.pause()
dataloader_iterator = iter(dataloader)
trigger_dataloader_setup_epoch(dataloader)
batch = next(dataloader_iterator)
with self.timer('reset_batch'):
# hit the end of an epoch, reset
self.progress_bar.pause()
dataloader_iterator = iter(dataloader)
trigger_dataloader_setup_epoch(dataloader)
with self.timer('get_batch'):
batch = next(dataloader_iterator)
self.progress_bar.unpause()
else:
batch = None
@@ -960,7 +972,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.network.is_normalizing = True
# flush()
### HOOK ###
self.timer.start('train_loop')
loss_dict = self.hook_train_loop(batch)
self.timer.stop('train_loop')
# flush()
# setup the networks to gradient checkpointing and everything works
@@ -998,11 +1012,19 @@ class BaseSDTrainProcess(BaseTrainProcess):
if self.logging_config.log_every and self.step_num % self.logging_config.log_every == 0:
self.progress_bar.pause()
# log to tensorboard
if self.writer is not None:
for key, value in loss_dict.items():
self.writer.add_scalar(f"{key}", value, self.step_num)
self.writer.add_scalar(f"lr", learning_rate, self.step_num)
with self.timer('log_to_tensorboard'):
# log to tensorboard
if self.writer is not None:
for key, value in loss_dict.items():
self.writer.add_scalar(f"{key}", value, self.step_num)
self.writer.add_scalar(f"lr", learning_rate, self.step_num)
self.progress_bar.unpause()
if self.performance_log_every > 0 and self.step_num % self.performance_log_every == 0:
self.progress_bar.pause()
# print the timers and clear them
self.timer.print()
self.timer.reset()
self.progress_bar.unpause()
# sets progress bar to match out step
@@ -1012,11 +1034,13 @@ class BaseSDTrainProcess(BaseTrainProcess):
# apply network normalizer if we are using it, not on regularization steps
if self.network is not None and self.network.is_normalizing and not is_reg_step:
self.network.apply_stored_normalizer()
with self.timer('apply_normalizer'):
self.network.apply_stored_normalizer()
# if the batch is a DataLoaderBatchDTO, then we need to clean it up
if isinstance(batch, DataLoaderBatchDTO):
batch.cleanup()
with self.timer('batch_cleanup'):
batch.cleanup()
# flush every 10 steps
# if self.step_num % 10 == 0:

View File

@@ -448,7 +448,7 @@ class AiToolkitDataset(LatentCachingMixin, BucketsMixin, CaptionMixin, Dataset):
return len(self.file_list)
def _get_single_item(self, index) -> 'FileItemDTO':
file_item = self.file_list[index]
file_item = copy.deepcopy(self.file_list[index])
file_item.load_and_process_image(self.transform)
file_item.load_caption(self.caption_dict)
return file_item
@@ -529,14 +529,14 @@ def get_dataloader_from_datasets(
drop_last=False,
shuffle=True,
collate_fn=dto_collation, # Use the custom collate function
num_workers=0
num_workers=4
)
else:
data_loader = DataLoader(
concatenated_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=0,
num_workers=4,
collate_fn=dto_collation
)
return data_loader

View File

@@ -670,7 +670,8 @@ class LatentCachingFileItemDTOMixin:
# load it from disk
state_dict = load_file(
self.get_latent_path(),
device=device if device is not None else self.latent_load_device
# device=device if device is not None else self.latent_load_device
device='cpu'
)
self._encoded_latent = state_dict['latent']
return self._encoded_latent

View File

@@ -53,7 +53,7 @@ class ToolkitModuleMixin:
super().__init__(*args, **kwargs)
self.network_ref: weakref.ref = weakref.ref(network)
self.is_checkpointing = False
self.is_normalizing = False
# self.is_normalizing = False
self.normalize_scaler = 1.0
self._multiplier: Union[float, list, torch.Tensor] = None
@@ -118,7 +118,7 @@ class ToolkitModuleMixin:
multiplier = multiplier.repeat_interleave(num_interleaves)
# multiplier = 1.0
if self.is_normalizing:
if self.network_ref().is_normalizing:
with torch.no_grad():
# do this calculation without set multiplier and instead use same polarity, but with 1.0 multiplier
@@ -390,8 +390,8 @@ class ToolkitNetworkMixin:
@is_normalizing.setter
def is_normalizing(self: Network, value: bool):
self._is_normalizing = value
for module in self.get_all_modules():
module.is_normalizing = self._is_normalizing
# for module in self.get_all_modules():
# module.is_normalizing = self._is_normalizing
def apply_stored_normalizer(self: Network, target_normalize_scaler: float = 1.0):
for module in self.get_all_modules():

65
toolkit/timer.py Normal file
View File

@@ -0,0 +1,65 @@
import time
from collections import OrderedDict, deque
class Timer:
    """Named stopwatch collection for lightweight performance profiling.

    Each timer name keeps up to ``max_buffer`` of its most recent elapsed
    times; ``print()`` reports the per-name average, slowest first. Can be
    used imperatively (``start``/``stop``) or as a context manager via
    ``with timer('section'): ...`` — context-manager use is nestable.
    """

    def __init__(self, name='Timer', max_buffer=10):
        self.name = name
        # max samples retained per timer name (enforced by deque maxlen)
        self.max_buffer = max_buffer
        # timer_name -> deque of recent elapsed times (seconds)
        self.timers = OrderedDict()
        # timer_name -> start timestamp for timers currently running
        self.active_timers = {}
        # kept for backward compatibility: most recently entered context name
        self.current_timer = None
        # stack of context-manager names, so nested `with` blocks each stop
        # their own timer (a single attribute breaks on nesting)
        self._context_stack = []

    def start(self, timer_name):
        """Begin timing ``timer_name`` (creates its sample buffer on first use)."""
        if timer_name not in self.timers:
            # maxlen makes the deque discard its oldest sample automatically
            self.timers[timer_name] = deque(maxlen=self.max_buffer)
        # perf_counter is monotonic and higher-resolution than time.time
        self.active_timers[timer_name] = time.perf_counter()

    def cancel(self, timer_name):
        """Cancel an active timer without recording a sample."""
        self.active_timers.pop(timer_name, None)

    def stop(self, timer_name):
        """Stop ``timer_name`` and record its elapsed time.

        Raises:
            ValueError: if the timer was never started (or already stopped).
        """
        if timer_name not in self.active_timers:
            raise ValueError(f"Timer '{timer_name}' was not started!")
        elapsed_time = time.perf_counter() - self.active_timers.pop(timer_name)
        # deque(maxlen=...) already evicts the oldest entry when full,
        # so no manual popleft is needed here
        self.timers[timer_name].append(elapsed_time)

    def print(self):
        """Print the average elapsed time per timer, slowest total first."""
        print(f"\nTimer '{self.name}':")
        # sort by total time so the biggest costs appear at the top
        for timer_name, timings in sorted(self.timers.items(), key=lambda x: sum(x[1]), reverse=True):
            if not timings:
                # started but never stopped: nothing to average
                continue
            avg_time = sum(timings) / len(timings)
            print(f" - {avg_time:.4f}s avg - {timer_name}, num = {len(timings)}")
        print('')

    def reset(self):
        """Discard all recorded samples and forget any active timers."""
        self.timers.clear()
        self.active_timers.clear()
        self._context_stack.clear()
        self.current_timer = None

    def __call__(self, timer_name):
        """Enable use as a context manager: ``with timer('name'): ...``."""
        self.current_timer = timer_name
        self._context_stack.append(timer_name)
        self.start(timer_name)
        return self

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # pop this context's own name so nested `with` blocks stop the
        # correct timer (not just whatever was entered last)
        timer_name = self._context_stack.pop() if self._context_stack else self.current_timer
        if exc_type is None:
            # no exception: record the sample normally
            self.stop(timer_name)
        else:
            # exception inside the block: drop the sample rather than
            # record a misleading elapsed time
            self.cancel(timer_name)