Massive speed increases and RAM optimizations

This commit is contained in:
Jaret Burkett
2023-10-10 06:07:55 -06:00
parent f4c90bb589
commit 63ceffae24
7 changed files with 294 additions and 185 deletions

View File

@@ -2,6 +2,8 @@ import copy
import json
from collections import OrderedDict
from toolkit.timer import Timer
class BaseProcess(object):
@@ -18,6 +20,9 @@ class BaseProcess(object):
self.raw_process_config = config
self.name = self.get_conf('name', self.job.name)
self.meta = copy.deepcopy(self.job.meta)
self.timer: Timer = Timer(f'{self.name} Timer')
self.performance_log_every = self.get_conf('performance_log_every', 0)
print(json.dumps(self.config, indent=4))
def get_conf(self, key, default=None, required=False, as_type=None):

View File

@@ -144,6 +144,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.named_lora = False
if self.embed_config is not None or is_training_adapter:
self.named_lora = True
def post_process_generate_image_config_list(self, generate_image_config_list: List[GenerateImageConfig]):
# Extension hook: takes the list of GenerateImageConfig objects and hands it back.
# This base version returns the very same list untouched; a subclass may override
# the method to alter the list (where the hook is called is not visible here).
# override in subclass
return generate_image_config_list
@@ -442,98 +443,102 @@ class BaseSDTrainProcess(BaseTrainProcess):
def process_general_training_batch(self, batch: 'DataLoaderBatchDTO'):
with torch.no_grad():
prompts = batch.get_caption_list()
is_reg_list = batch.get_is_reg_list()
with self.timer('prepare_prompt'):
prompts = batch.get_caption_list()
is_reg_list = batch.get_is_reg_list()
conditioned_prompts = []
conditioned_prompts = []
for prompt, is_reg in zip(prompts, is_reg_list):
for prompt, is_reg in zip(prompts, is_reg_list):
# make sure the embedding is in the prompts
if self.embedding is not None:
prompt = self.embedding.inject_embedding_to_prompt(
prompt,
expand_token=True,
add_if_not_present=not is_reg,
)
# make sure the embedding is in the prompts
if self.embedding is not None:
prompt = self.embedding.inject_embedding_to_prompt(
prompt,
expand_token=True,
add_if_not_present=not is_reg,
)
# make sure trigger is in the prompts if not a regularization run
if self.trigger_word is not None:
prompt = self.sd.inject_trigger_into_prompt(
prompt,
trigger=self.trigger_word,
add_if_not_present=not is_reg,
)
conditioned_prompts.append(prompt)
# make sure trigger is in the prompts if not a regularization run
if self.trigger_word is not None:
prompt = self.sd.inject_trigger_into_prompt(
prompt,
trigger=self.trigger_word,
add_if_not_present=not is_reg,
)
conditioned_prompts.append(prompt)
dtype = get_torch_dtype(self.train_config.dtype)
imgs = None
if batch.tensor is not None:
imgs = batch.tensor
imgs = imgs.to(self.device_torch, dtype=dtype)
if batch.latents is not None:
latents = batch.latents.to(self.device_torch, dtype=dtype)
batch.latents = latents
else:
latents = self.sd.encode_images(imgs)
batch.latents = latents
flush()
with self.timer('prepare_latents'):
dtype = get_torch_dtype(self.train_config.dtype)
imgs = None
if batch.tensor is not None:
imgs = batch.tensor
imgs = imgs.to(self.device_torch, dtype=dtype)
if batch.latents is not None:
latents = batch.latents.to(self.device_torch, dtype=dtype)
batch.latents = latents
else:
latents = self.sd.encode_images(imgs)
batch.latents = latents
# flush() # todo check performance removing this
batch_size = latents.shape[0]
self.sd.noise_scheduler.set_timesteps(
1000, device=self.device_torch
)
with self.timer('prepare_noise'):
# if self.train_config.timestep_sampling == 'style' or self.train_config.timestep_sampling == 'content':
if self.train_config.content_or_style in ['style', 'content']:
# this is from diffusers training code
# Cubic sampling for favoring later or earlier timesteps
# For more details about why cubic sampling is used for content / structure,
# refer to section 3.4 of https://arxiv.org/abs/2302.08453
# for content / structure, it is best to favor earlier timesteps
# for style, it is best to favor later timesteps
timesteps = torch.rand((batch_size,), device=latents.device)
if self.train_config.content_or_style == 'style':
timesteps = timesteps ** 3 * self.sd.noise_scheduler.config['num_train_timesteps']
elif self.train_config.content_or_style == 'content':
timesteps = (1 - timesteps ** 3) * self.sd.noise_scheduler.config['num_train_timesteps']
timesteps = value_map(
timesteps,
0,
self.sd.noise_scheduler.config['num_train_timesteps'] - 1,
self.train_config.min_denoising_steps,
self.train_config.max_denoising_steps
)
timesteps = timesteps.long().clamp(
self.train_config.min_denoising_steps,
self.train_config.max_denoising_steps - 1
self.sd.noise_scheduler.set_timesteps(
1000, device=self.device_torch
)
elif self.train_config.content_or_style == 'balanced':
timesteps = torch.randint(
self.train_config.min_denoising_steps,
self.train_config.max_denoising_steps,
(batch_size,),
device=self.device_torch
)
timesteps = timesteps.long()
else:
raise ValueError(f"Unknown content_or_style {self.train_config.content_or_style}")
# if self.train_config.timestep_sampling == 'style' or self.train_config.timestep_sampling == 'content':
if self.train_config.content_or_style in ['style', 'content']:
# this is from diffusers training code
# Cubic sampling for favoring later or earlier timesteps
# For more details about why cubic sampling is used for content / structure,
# refer to section 3.4 of https://arxiv.org/abs/2302.08453
# get noise
noise = self.sd.get_latent_noise(
height=latents.shape[2],
width=latents.shape[3],
batch_size=batch_size,
noise_offset=self.train_config.noise_offset
).to(self.device_torch, dtype=dtype)
# for content / structure, it is best to favor earlier timesteps
# for style, it is best to favor later timesteps
noisy_latents = self.sd.noise_scheduler.add_noise(latents, noise, timesteps)
timesteps = torch.rand((batch_size,), device=latents.device)
if self.train_config.content_or_style == 'style':
timesteps = timesteps ** 3 * self.sd.noise_scheduler.config['num_train_timesteps']
elif self.train_config.content_or_style == 'content':
timesteps = (1 - timesteps ** 3) * self.sd.noise_scheduler.config['num_train_timesteps']
timesteps = value_map(
timesteps,
0,
self.sd.noise_scheduler.config['num_train_timesteps'] - 1,
self.train_config.min_denoising_steps,
self.train_config.max_denoising_steps
)
timesteps = timesteps.long().clamp(
self.train_config.min_denoising_steps,
self.train_config.max_denoising_steps - 1
)
elif self.train_config.content_or_style == 'balanced':
timesteps = torch.randint(
self.train_config.min_denoising_steps,
self.train_config.max_denoising_steps,
(batch_size,),
device=self.device_torch
)
timesteps = timesteps.long()
else:
raise ValueError(f"Unknown content_or_style {self.train_config.content_or_style}")
# get noise
noise = self.sd.get_latent_noise(
height=latents.shape[2],
width=latents.shape[3],
batch_size=batch_size,
noise_offset=self.train_config.noise_offset
).to(self.device_torch, dtype=dtype)
noisy_latents = self.sd.noise_scheduler.add_noise(latents, noise, timesteps)
# remove grads for these
noisy_latents.requires_grad = False
@@ -933,24 +938,31 @@ class BaseSDTrainProcess(BaseTrainProcess):
# don't do a reg step on sample or save steps as we don't want to normalize on those
if step % 2 == 0 and dataloader_reg is not None and not is_save_step and not is_sample_step:
try:
batch = next(dataloader_iterator_reg)
with self.timer('get_batch:reg'):
batch = next(dataloader_iterator_reg)
except StopIteration:
# hit the end of an epoch, reset
self.progress_bar.pause()
dataloader_iterator_reg = iter(dataloader_reg)
trigger_dataloader_setup_epoch(dataloader_reg)
batch = next(dataloader_iterator_reg)
with self.timer('reset_batch:reg'):
# hit the end of an epoch, reset
self.progress_bar.pause()
dataloader_iterator_reg = iter(dataloader_reg)
trigger_dataloader_setup_epoch(dataloader_reg)
with self.timer('get_batch:reg'):
batch = next(dataloader_iterator_reg)
self.progress_bar.unpause()
is_reg_step = True
elif dataloader is not None:
try:
batch = next(dataloader_iterator)
with self.timer('get_batch'):
batch = next(dataloader_iterator)
except StopIteration:
# hit the end of an epoch, reset
self.progress_bar.pause()
dataloader_iterator = iter(dataloader)
trigger_dataloader_setup_epoch(dataloader)
batch = next(dataloader_iterator)
with self.timer('reset_batch'):
# hit the end of an epoch, reset
self.progress_bar.pause()
dataloader_iterator = iter(dataloader)
trigger_dataloader_setup_epoch(dataloader)
with self.timer('get_batch'):
batch = next(dataloader_iterator)
self.progress_bar.unpause()
else:
batch = None
@@ -960,7 +972,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.network.is_normalizing = True
# flush()
### HOOK ###
self.timer.start('train_loop')
loss_dict = self.hook_train_loop(batch)
self.timer.stop('train_loop')
# flush()
# setup the networks to gradient checkpointing and everything works
@@ -998,11 +1012,19 @@ class BaseSDTrainProcess(BaseTrainProcess):
if self.logging_config.log_every and self.step_num % self.logging_config.log_every == 0:
self.progress_bar.pause()
# log to tensorboard
if self.writer is not None:
for key, value in loss_dict.items():
self.writer.add_scalar(f"{key}", value, self.step_num)
self.writer.add_scalar(f"lr", learning_rate, self.step_num)
with self.timer('log_to_tensorboard'):
# log to tensorboard
if self.writer is not None:
for key, value in loss_dict.items():
self.writer.add_scalar(f"{key}", value, self.step_num)
self.writer.add_scalar(f"lr", learning_rate, self.step_num)
self.progress_bar.unpause()
if self.performance_log_every > 0 and self.step_num % self.performance_log_every == 0:
self.progress_bar.pause()
# print the timers and clear them
self.timer.print()
self.timer.reset()
self.progress_bar.unpause()
# sets progress bar to match our step
@@ -1012,11 +1034,13 @@ class BaseSDTrainProcess(BaseTrainProcess):
# apply network normalizer if we are using it, not on regularization steps
if self.network is not None and self.network.is_normalizing and not is_reg_step:
self.network.apply_stored_normalizer()
with self.timer('apply_normalizer'):
self.network.apply_stored_normalizer()
# if the batch is a DataLoaderBatchDTO, then we need to clean it up
if isinstance(batch, DataLoaderBatchDTO):
batch.cleanup()
with self.timer('batch_cleanup'):
batch.cleanup()
# flush every 10 steps
# if self.step_num % 10 == 0: