mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-01 11:41:35 +00:00
Massive speed increases and RAM optimizations
This commit is contained in:
@@ -2,6 +2,8 @@ import copy
|
||||
import json
|
||||
from collections import OrderedDict
|
||||
|
||||
from toolkit.timer import Timer
|
||||
|
||||
|
||||
class BaseProcess(object):
|
||||
|
||||
@@ -18,6 +20,9 @@ class BaseProcess(object):
|
||||
self.raw_process_config = config
|
||||
self.name = self.get_conf('name', self.job.name)
|
||||
self.meta = copy.deepcopy(self.job.meta)
|
||||
self.timer: Timer = Timer(f'{self.name} Timer')
|
||||
self.performance_log_every = self.get_conf('performance_log_every', 0)
|
||||
|
||||
print(json.dumps(self.config, indent=4))
|
||||
|
||||
def get_conf(self, key, default=None, required=False, as_type=None):
|
||||
|
||||
@@ -144,6 +144,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
self.named_lora = False
|
||||
if self.embed_config is not None or is_training_adapter:
|
||||
self.named_lora = True
|
||||
|
||||
def post_process_generate_image_config_list(self, generate_image_config_list: List[GenerateImageConfig]):
|
||||
# override in subclass
|
||||
return generate_image_config_list
|
||||
@@ -442,98 +443,102 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
|
||||
def process_general_training_batch(self, batch: 'DataLoaderBatchDTO'):
|
||||
with torch.no_grad():
|
||||
prompts = batch.get_caption_list()
|
||||
is_reg_list = batch.get_is_reg_list()
|
||||
with self.timer('prepare_prompt'):
|
||||
prompts = batch.get_caption_list()
|
||||
is_reg_list = batch.get_is_reg_list()
|
||||
|
||||
conditioned_prompts = []
|
||||
conditioned_prompts = []
|
||||
|
||||
for prompt, is_reg in zip(prompts, is_reg_list):
|
||||
for prompt, is_reg in zip(prompts, is_reg_list):
|
||||
|
||||
# make sure the embedding is in the prompts
|
||||
if self.embedding is not None:
|
||||
prompt = self.embedding.inject_embedding_to_prompt(
|
||||
prompt,
|
||||
expand_token=True,
|
||||
add_if_not_present=not is_reg,
|
||||
)
|
||||
# make sure the embedding is in the prompts
|
||||
if self.embedding is not None:
|
||||
prompt = self.embedding.inject_embedding_to_prompt(
|
||||
prompt,
|
||||
expand_token=True,
|
||||
add_if_not_present=not is_reg,
|
||||
)
|
||||
|
||||
# make sure trigger is in the prompts if not a regularization run
|
||||
if self.trigger_word is not None:
|
||||
prompt = self.sd.inject_trigger_into_prompt(
|
||||
prompt,
|
||||
trigger=self.trigger_word,
|
||||
add_if_not_present=not is_reg,
|
||||
)
|
||||
conditioned_prompts.append(prompt)
|
||||
# make sure trigger is in the prompts if not a regularization run
|
||||
if self.trigger_word is not None:
|
||||
prompt = self.sd.inject_trigger_into_prompt(
|
||||
prompt,
|
||||
trigger=self.trigger_word,
|
||||
add_if_not_present=not is_reg,
|
||||
)
|
||||
conditioned_prompts.append(prompt)
|
||||
|
||||
dtype = get_torch_dtype(self.train_config.dtype)
|
||||
imgs = None
|
||||
if batch.tensor is not None:
|
||||
imgs = batch.tensor
|
||||
imgs = imgs.to(self.device_torch, dtype=dtype)
|
||||
if batch.latents is not None:
|
||||
latents = batch.latents.to(self.device_torch, dtype=dtype)
|
||||
batch.latents = latents
|
||||
else:
|
||||
latents = self.sd.encode_images(imgs)
|
||||
batch.latents = latents
|
||||
flush()
|
||||
with self.timer('prepare_latents'):
|
||||
dtype = get_torch_dtype(self.train_config.dtype)
|
||||
imgs = None
|
||||
if batch.tensor is not None:
|
||||
imgs = batch.tensor
|
||||
imgs = imgs.to(self.device_torch, dtype=dtype)
|
||||
if batch.latents is not None:
|
||||
latents = batch.latents.to(self.device_torch, dtype=dtype)
|
||||
batch.latents = latents
|
||||
else:
|
||||
latents = self.sd.encode_images(imgs)
|
||||
batch.latents = latents
|
||||
# flush() # todo check performance removing this
|
||||
|
||||
batch_size = latents.shape[0]
|
||||
|
||||
self.sd.noise_scheduler.set_timesteps(
|
||||
1000, device=self.device_torch
|
||||
)
|
||||
with self.timer('prepare_noise'):
|
||||
|
||||
# if self.train_config.timestep_sampling == 'style' or self.train_config.timestep_sampling == 'content':
|
||||
if self.train_config.content_or_style in ['style', 'content']:
|
||||
# this is from diffusers training code
|
||||
# Cubic sampling for favoring later or earlier timesteps
|
||||
# For more details about why cubic sampling is used for content / structure,
|
||||
# refer to section 3.4 of https://arxiv.org/abs/2302.08453
|
||||
|
||||
# for content / structure, it is best to favor earlier timesteps
|
||||
# for style, it is best to favor later timesteps
|
||||
|
||||
timesteps = torch.rand((batch_size,), device=latents.device)
|
||||
|
||||
if self.train_config.content_or_style == 'style':
|
||||
timesteps = timesteps ** 3 * self.sd.noise_scheduler.config['num_train_timesteps']
|
||||
elif self.train_config.content_or_style == 'content':
|
||||
timesteps = (1 - timesteps ** 3) * self.sd.noise_scheduler.config['num_train_timesteps']
|
||||
|
||||
timesteps = value_map(
|
||||
timesteps,
|
||||
0,
|
||||
self.sd.noise_scheduler.config['num_train_timesteps'] - 1,
|
||||
self.train_config.min_denoising_steps,
|
||||
self.train_config.max_denoising_steps
|
||||
)
|
||||
timesteps = timesteps.long().clamp(
|
||||
self.train_config.min_denoising_steps,
|
||||
self.train_config.max_denoising_steps - 1
|
||||
self.sd.noise_scheduler.set_timesteps(
|
||||
1000, device=self.device_torch
|
||||
)
|
||||
|
||||
elif self.train_config.content_or_style == 'balanced':
|
||||
timesteps = torch.randint(
|
||||
self.train_config.min_denoising_steps,
|
||||
self.train_config.max_denoising_steps,
|
||||
(batch_size,),
|
||||
device=self.device_torch
|
||||
)
|
||||
timesteps = timesteps.long()
|
||||
else:
|
||||
raise ValueError(f"Unknown content_or_style {self.train_config.content_or_style}")
|
||||
# if self.train_config.timestep_sampling == 'style' or self.train_config.timestep_sampling == 'content':
|
||||
if self.train_config.content_or_style in ['style', 'content']:
|
||||
# this is from diffusers training code
|
||||
# Cubic sampling for favoring later or earlier timesteps
|
||||
# For more details about why cubic sampling is used for content / structure,
|
||||
# refer to section 3.4 of https://arxiv.org/abs/2302.08453
|
||||
|
||||
# get noise
|
||||
noise = self.sd.get_latent_noise(
|
||||
height=latents.shape[2],
|
||||
width=latents.shape[3],
|
||||
batch_size=batch_size,
|
||||
noise_offset=self.train_config.noise_offset
|
||||
).to(self.device_torch, dtype=dtype)
|
||||
# for content / structure, it is best to favor earlier timesteps
|
||||
# for style, it is best to favor later timesteps
|
||||
|
||||
noisy_latents = self.sd.noise_scheduler.add_noise(latents, noise, timesteps)
|
||||
timesteps = torch.rand((batch_size,), device=latents.device)
|
||||
|
||||
if self.train_config.content_or_style == 'style':
|
||||
timesteps = timesteps ** 3 * self.sd.noise_scheduler.config['num_train_timesteps']
|
||||
elif self.train_config.content_or_style == 'content':
|
||||
timesteps = (1 - timesteps ** 3) * self.sd.noise_scheduler.config['num_train_timesteps']
|
||||
|
||||
timesteps = value_map(
|
||||
timesteps,
|
||||
0,
|
||||
self.sd.noise_scheduler.config['num_train_timesteps'] - 1,
|
||||
self.train_config.min_denoising_steps,
|
||||
self.train_config.max_denoising_steps
|
||||
)
|
||||
timesteps = timesteps.long().clamp(
|
||||
self.train_config.min_denoising_steps,
|
||||
self.train_config.max_denoising_steps - 1
|
||||
)
|
||||
|
||||
elif self.train_config.content_or_style == 'balanced':
|
||||
timesteps = torch.randint(
|
||||
self.train_config.min_denoising_steps,
|
||||
self.train_config.max_denoising_steps,
|
||||
(batch_size,),
|
||||
device=self.device_torch
|
||||
)
|
||||
timesteps = timesteps.long()
|
||||
else:
|
||||
raise ValueError(f"Unknown content_or_style {self.train_config.content_or_style}")
|
||||
|
||||
# get noise
|
||||
noise = self.sd.get_latent_noise(
|
||||
height=latents.shape[2],
|
||||
width=latents.shape[3],
|
||||
batch_size=batch_size,
|
||||
noise_offset=self.train_config.noise_offset
|
||||
).to(self.device_torch, dtype=dtype)
|
||||
|
||||
noisy_latents = self.sd.noise_scheduler.add_noise(latents, noise, timesteps)
|
||||
|
||||
# remove grads for these
|
||||
noisy_latents.requires_grad = False
|
||||
@@ -933,24 +938,31 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
# don't do a reg step on sample or save steps, as we don't want to normalize on those
|
||||
if step % 2 == 0 and dataloader_reg is not None and not is_save_step and not is_sample_step:
|
||||
try:
|
||||
batch = next(dataloader_iterator_reg)
|
||||
with self.timer('get_batch:reg'):
|
||||
batch = next(dataloader_iterator_reg)
|
||||
except StopIteration:
|
||||
# hit the end of an epoch, reset
|
||||
self.progress_bar.pause()
|
||||
dataloader_iterator_reg = iter(dataloader_reg)
|
||||
trigger_dataloader_setup_epoch(dataloader_reg)
|
||||
batch = next(dataloader_iterator_reg)
|
||||
with self.timer('reset_batch:reg'):
|
||||
# hit the end of an epoch, reset
|
||||
self.progress_bar.pause()
|
||||
dataloader_iterator_reg = iter(dataloader_reg)
|
||||
trigger_dataloader_setup_epoch(dataloader_reg)
|
||||
|
||||
with self.timer('get_batch:reg'):
|
||||
batch = next(dataloader_iterator_reg)
|
||||
self.progress_bar.unpause()
|
||||
is_reg_step = True
|
||||
elif dataloader is not None:
|
||||
try:
|
||||
batch = next(dataloader_iterator)
|
||||
with self.timer('get_batch'):
|
||||
batch = next(dataloader_iterator)
|
||||
except StopIteration:
|
||||
# hit the end of an epoch, reset
|
||||
self.progress_bar.pause()
|
||||
dataloader_iterator = iter(dataloader)
|
||||
trigger_dataloader_setup_epoch(dataloader)
|
||||
batch = next(dataloader_iterator)
|
||||
with self.timer('reset_batch'):
|
||||
# hit the end of an epoch, reset
|
||||
self.progress_bar.pause()
|
||||
dataloader_iterator = iter(dataloader)
|
||||
trigger_dataloader_setup_epoch(dataloader)
|
||||
with self.timer('get_batch'):
|
||||
batch = next(dataloader_iterator)
|
||||
self.progress_bar.unpause()
|
||||
else:
|
||||
batch = None
|
||||
@@ -960,7 +972,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
self.network.is_normalizing = True
|
||||
# flush()
|
||||
### HOOK ###
|
||||
self.timer.start('train_loop')
|
||||
loss_dict = self.hook_train_loop(batch)
|
||||
self.timer.stop('train_loop')
|
||||
# flush()
|
||||
# setup the networks to gradient checkpointing and everything works
|
||||
|
||||
@@ -998,11 +1012,19 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
|
||||
if self.logging_config.log_every and self.step_num % self.logging_config.log_every == 0:
|
||||
self.progress_bar.pause()
|
||||
# log to tensorboard
|
||||
if self.writer is not None:
|
||||
for key, value in loss_dict.items():
|
||||
self.writer.add_scalar(f"{key}", value, self.step_num)
|
||||
self.writer.add_scalar(f"lr", learning_rate, self.step_num)
|
||||
with self.timer('log_to_tensorboard'):
|
||||
# log to tensorboard
|
||||
if self.writer is not None:
|
||||
for key, value in loss_dict.items():
|
||||
self.writer.add_scalar(f"{key}", value, self.step_num)
|
||||
self.writer.add_scalar(f"lr", learning_rate, self.step_num)
|
||||
self.progress_bar.unpause()
|
||||
|
||||
if self.performance_log_every > 0 and self.step_num % self.performance_log_every == 0:
|
||||
self.progress_bar.pause()
|
||||
# print the timers and clear them
|
||||
self.timer.print()
|
||||
self.timer.reset()
|
||||
self.progress_bar.unpause()
|
||||
|
||||
# sets progress bar to match our step
|
||||
@@ -1012,11 +1034,13 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
|
||||
# apply network normalizer if we are using it, not on regularization steps
|
||||
if self.network is not None and self.network.is_normalizing and not is_reg_step:
|
||||
self.network.apply_stored_normalizer()
|
||||
with self.timer('apply_normalizer'):
|
||||
self.network.apply_stored_normalizer()
|
||||
|
||||
# if the batch is a DataLoaderBatchDTO, then we need to clean it up
|
||||
if isinstance(batch, DataLoaderBatchDTO):
|
||||
batch.cleanup()
|
||||
with self.timer('batch_cleanup'):
|
||||
batch.cleanup()
|
||||
|
||||
# flush every 10 steps
|
||||
# if self.step_num % 10 == 0:
|
||||
|
||||
Reference in New Issue
Block a user