mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-10 21:19:49 +00:00
Performance optimizations for pre processing the batch
This commit is contained in:
10
README.md
10
README.md
@@ -417,6 +417,16 @@ Everything else should work the same including layer targeting.
|
||||
|
||||
## Updates
|
||||
|
||||
### June 17, 2024
|
||||
- Performance optimizations for batch preparation
|
||||
|
||||
### June 16, 2024
|
||||
- Hide control images in the UI when viewing datasets
|
||||
- WIP on mean flow loss
|
||||
|
||||
### June 12, 2024
|
||||
- Fixed issue that resulted in blank captions in the dataloader
|
||||
|
||||
### June 10, 2024
|
||||
- Decided to keep track of updates in the readme
|
||||
- Added support for SDXL in the UI
|
||||
|
||||
@@ -1009,76 +1009,76 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
return loss
|
||||
|
||||
def train_single_accumulation(self, batch: DataLoaderBatchDTO):
|
||||
self.timer.start('preprocess_batch')
|
||||
if isinstance(self.adapter, CustomAdapter):
|
||||
batch = self.adapter.edit_batch_raw(batch)
|
||||
batch = self.preprocess_batch(batch)
|
||||
if isinstance(self.adapter, CustomAdapter):
|
||||
batch = self.adapter.edit_batch_processed(batch)
|
||||
dtype = get_torch_dtype(self.train_config.dtype)
|
||||
# sanity check
|
||||
if self.sd.vae.dtype != self.sd.vae_torch_dtype:
|
||||
self.sd.vae = self.sd.vae.to(self.sd.vae_torch_dtype)
|
||||
if isinstance(self.sd.text_encoder, list):
|
||||
for encoder in self.sd.text_encoder:
|
||||
if encoder.dtype != self.sd.te_torch_dtype:
|
||||
encoder.to(self.sd.te_torch_dtype)
|
||||
else:
|
||||
if self.sd.text_encoder.dtype != self.sd.te_torch_dtype:
|
||||
self.sd.text_encoder.to(self.sd.te_torch_dtype)
|
||||
|
||||
noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)
|
||||
if self.train_config.do_cfg or self.train_config.do_random_cfg:
|
||||
# pick random negative prompts
|
||||
if self.negative_prompt_pool is not None:
|
||||
negative_prompts = []
|
||||
for i in range(noisy_latents.shape[0]):
|
||||
num_neg = random.randint(1, self.train_config.max_negative_prompts)
|
||||
this_neg_prompts = [random.choice(self.negative_prompt_pool) for _ in range(num_neg)]
|
||||
this_neg_prompt = ', '.join(this_neg_prompts)
|
||||
negative_prompts.append(this_neg_prompt)
|
||||
self.batch_negative_prompt = negative_prompts
|
||||
else:
|
||||
self.batch_negative_prompt = ['' for _ in range(batch.latents.shape[0])]
|
||||
|
||||
if self.adapter and isinstance(self.adapter, CustomAdapter):
|
||||
# condition the prompt
|
||||
# todo handle more than one adapter image
|
||||
conditioned_prompts = self.adapter.condition_prompt(conditioned_prompts)
|
||||
|
||||
network_weight_list = batch.get_network_weight_list()
|
||||
if self.train_config.single_item_batching:
|
||||
network_weight_list = network_weight_list + network_weight_list
|
||||
|
||||
has_adapter_img = batch.control_tensor is not None
|
||||
has_clip_image = batch.clip_image_tensor is not None
|
||||
has_clip_image_embeds = batch.clip_image_embeds is not None
|
||||
# force it to be true if doing regs as we handle those differently
|
||||
if any([batch.file_items[idx].is_reg for idx in range(len(batch.file_items))]):
|
||||
has_clip_image = True
|
||||
if self._clip_image_embeds_unconditional is not None:
|
||||
has_clip_image_embeds = True # we are caching embeds, handle that differently
|
||||
has_clip_image = False
|
||||
|
||||
if self.adapter is not None and isinstance(self.adapter, IPAdapter) and not has_clip_image and has_adapter_img:
|
||||
raise ValueError(
|
||||
"IPAdapter control image is now 'clip_image_path' instead of 'control_path'. Please update your dataset config ")
|
||||
|
||||
match_adapter_assist = False
|
||||
|
||||
# check if we are matching the adapter assistant
|
||||
if self.assistant_adapter:
|
||||
if self.train_config.match_adapter_chance == 1.0:
|
||||
match_adapter_assist = True
|
||||
elif self.train_config.match_adapter_chance > 0.0:
|
||||
match_adapter_assist = torch.rand(
|
||||
(1,), device=self.device_torch, dtype=dtype
|
||||
) < self.train_config.match_adapter_chance
|
||||
|
||||
self.timer.stop('preprocess_batch')
|
||||
|
||||
is_reg = False
|
||||
with torch.no_grad():
|
||||
self.timer.start('preprocess_batch')
|
||||
if isinstance(self.adapter, CustomAdapter):
|
||||
batch = self.adapter.edit_batch_raw(batch)
|
||||
batch = self.preprocess_batch(batch)
|
||||
if isinstance(self.adapter, CustomAdapter):
|
||||
batch = self.adapter.edit_batch_processed(batch)
|
||||
dtype = get_torch_dtype(self.train_config.dtype)
|
||||
# sanity check
|
||||
if self.sd.vae.dtype != self.sd.vae_torch_dtype:
|
||||
self.sd.vae = self.sd.vae.to(self.sd.vae_torch_dtype)
|
||||
if isinstance(self.sd.text_encoder, list):
|
||||
for encoder in self.sd.text_encoder:
|
||||
if encoder.dtype != self.sd.te_torch_dtype:
|
||||
encoder.to(self.sd.te_torch_dtype)
|
||||
else:
|
||||
if self.sd.text_encoder.dtype != self.sd.te_torch_dtype:
|
||||
self.sd.text_encoder.to(self.sd.te_torch_dtype)
|
||||
|
||||
noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)
|
||||
if self.train_config.do_cfg or self.train_config.do_random_cfg:
|
||||
# pick random negative prompts
|
||||
if self.negative_prompt_pool is not None:
|
||||
negative_prompts = []
|
||||
for i in range(noisy_latents.shape[0]):
|
||||
num_neg = random.randint(1, self.train_config.max_negative_prompts)
|
||||
this_neg_prompts = [random.choice(self.negative_prompt_pool) for _ in range(num_neg)]
|
||||
this_neg_prompt = ', '.join(this_neg_prompts)
|
||||
negative_prompts.append(this_neg_prompt)
|
||||
self.batch_negative_prompt = negative_prompts
|
||||
else:
|
||||
self.batch_negative_prompt = ['' for _ in range(batch.latents.shape[0])]
|
||||
|
||||
if self.adapter and isinstance(self.adapter, CustomAdapter):
|
||||
# condition the prompt
|
||||
# todo handle more than one adapter image
|
||||
conditioned_prompts = self.adapter.condition_prompt(conditioned_prompts)
|
||||
|
||||
network_weight_list = batch.get_network_weight_list()
|
||||
if self.train_config.single_item_batching:
|
||||
network_weight_list = network_weight_list + network_weight_list
|
||||
|
||||
has_adapter_img = batch.control_tensor is not None
|
||||
has_clip_image = batch.clip_image_tensor is not None
|
||||
has_clip_image_embeds = batch.clip_image_embeds is not None
|
||||
# force it to be true if doing regs as we handle those differently
|
||||
if any([batch.file_items[idx].is_reg for idx in range(len(batch.file_items))]):
|
||||
has_clip_image = True
|
||||
if self._clip_image_embeds_unconditional is not None:
|
||||
has_clip_image_embeds = True # we are caching embeds, handle that differently
|
||||
has_clip_image = False
|
||||
|
||||
if self.adapter is not None and isinstance(self.adapter, IPAdapter) and not has_clip_image and has_adapter_img:
|
||||
raise ValueError(
|
||||
"IPAdapter control image is now 'clip_image_path' instead of 'control_path'. Please update your dataset config ")
|
||||
|
||||
match_adapter_assist = False
|
||||
|
||||
# check if we are matching the adapter assistant
|
||||
if self.assistant_adapter:
|
||||
if self.train_config.match_adapter_chance == 1.0:
|
||||
match_adapter_assist = True
|
||||
elif self.train_config.match_adapter_chance > 0.0:
|
||||
match_adapter_assist = torch.rand(
|
||||
(1,), device=self.device_torch, dtype=dtype
|
||||
) < self.train_config.match_adapter_chance
|
||||
|
||||
self.timer.stop('preprocess_batch')
|
||||
|
||||
is_reg = False
|
||||
loss_multiplier = torch.ones((noisy_latents.shape[0], 1, 1, 1), device=self.device_torch, dtype=dtype)
|
||||
for idx, file_item in enumerate(batch.file_items):
|
||||
if file_item.is_reg:
|
||||
|
||||
@@ -921,7 +921,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
noise = self.get_consistent_noise(latents, batch, dtype=dtype)
|
||||
else:
|
||||
if hasattr(self.sd, 'get_latent_noise_from_latents'):
|
||||
noise = self.sd.get_latent_noise_from_latents(latents).to(self.device_torch, dtype=dtype)
|
||||
noise = self.sd.get_latent_noise_from_latents(
|
||||
latents,
|
||||
noise_offset=self.train_config.noise_offset
|
||||
).to(self.device_torch, dtype=dtype)
|
||||
else:
|
||||
# get noise
|
||||
noise = self.sd.get_latent_noise(
|
||||
@@ -931,17 +934,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
batch_size=batch_size,
|
||||
noise_offset=self.train_config.noise_offset,
|
||||
).to(self.device_torch, dtype=dtype)
|
||||
|
||||
# if self.train_config.random_noise_shift > 0.0:
|
||||
# # get random noise -1 to 1
|
||||
# noise_shift = torch.rand((noise.shape[0], noise.shape[1], 1, 1), device=noise.device,
|
||||
# dtype=noise.dtype) * 2 - 1
|
||||
|
||||
# # multiply by shift amount
|
||||
# noise_shift *= self.train_config.random_noise_shift
|
||||
|
||||
# # add to noise
|
||||
# noise += noise_shift
|
||||
|
||||
if self.train_config.blended_blur_noise:
|
||||
noise = get_blended_blur_noise(
|
||||
@@ -1085,19 +1077,20 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
# we determine noise from the differential of the latents
|
||||
unaugmented_latents = self.sd.encode_images(batch.unaugmented_tensor)
|
||||
|
||||
batch_size = len(batch.file_items)
|
||||
min_noise_steps = self.train_config.min_denoising_steps
|
||||
max_noise_steps = self.train_config.max_denoising_steps
|
||||
if self.model_config.refiner_name_or_path is not None:
|
||||
# if we are not training the unet, then we are only doing refiner and do not need to double up
|
||||
if self.train_config.train_unet:
|
||||
max_noise_steps = round(self.train_config.max_denoising_steps * self.model_config.refiner_start_at)
|
||||
do_double = True
|
||||
else:
|
||||
min_noise_steps = round(self.train_config.max_denoising_steps * self.model_config.refiner_start_at)
|
||||
do_double = False
|
||||
with self.timer('prepare_scheduler'):
|
||||
|
||||
batch_size = len(batch.file_items)
|
||||
min_noise_steps = self.train_config.min_denoising_steps
|
||||
max_noise_steps = self.train_config.max_denoising_steps
|
||||
if self.model_config.refiner_name_or_path is not None:
|
||||
# if we are not training the unet, then we are only doing refiner and do not need to double up
|
||||
if self.train_config.train_unet:
|
||||
max_noise_steps = round(self.train_config.max_denoising_steps * self.model_config.refiner_start_at)
|
||||
do_double = True
|
||||
else:
|
||||
min_noise_steps = round(self.train_config.max_denoising_steps * self.model_config.refiner_start_at)
|
||||
do_double = False
|
||||
|
||||
with self.timer('prepare_noise'):
|
||||
num_train_timesteps = self.train_config.num_train_timesteps
|
||||
|
||||
if self.train_config.noise_scheduler in ['custom_lcm']:
|
||||
@@ -1144,6 +1137,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
self.sd.noise_scheduler.set_timesteps(
|
||||
num_train_timesteps, device=self.device_torch
|
||||
)
|
||||
with self.timer('prepare_timesteps_indices'):
|
||||
|
||||
content_or_style = self.train_config.content_or_style
|
||||
if is_reg:
|
||||
@@ -1193,20 +1187,26 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
timestep_indices = torch.ones((batch_size,), device=self.device_torch) * min_noise_steps
|
||||
else:
|
||||
# todo, some schedulers use indices, otheres use timesteps. Not sure what to do here
|
||||
min_idx = min_noise_steps + 1
|
||||
max_idx = max_noise_steps - 1
|
||||
if self.train_config.noise_scheduler == 'flowmatch':
|
||||
# flowmatch uses indices, so we need to use indices
|
||||
min_idx = 0
|
||||
max_idx = max_noise_steps - 1
|
||||
timestep_indices = torch.randint(
|
||||
min_noise_steps + 1,
|
||||
max_noise_steps - 1,
|
||||
min_idx,
|
||||
max_idx,
|
||||
(batch_size,),
|
||||
device=self.device_torch
|
||||
)
|
||||
timestep_indices = timestep_indices.long()
|
||||
else:
|
||||
raise ValueError(f"Unknown content_or_style {content_or_style}")
|
||||
|
||||
with self.timer('convert_timestep_indices_to_timesteps'):
|
||||
# convert the timestep_indices to a timestep
|
||||
timesteps = [self.sd.noise_scheduler.timesteps[x.item()] for x in timestep_indices]
|
||||
timesteps = torch.stack(timesteps, dim=0)
|
||||
|
||||
timesteps = self.sd.noise_scheduler.timesteps[timestep_indices.long()]
|
||||
|
||||
with self.timer('prepare_noise'):
|
||||
# get noise
|
||||
noise = self.get_noise(latents, batch_size, dtype=dtype, batch=batch, timestep=timesteps)
|
||||
|
||||
@@ -1240,6 +1240,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
device=noise.device,
|
||||
dtype=noise.dtype
|
||||
) * self.train_config.random_noise_multiplier
|
||||
|
||||
with self.timer('make_noisy_latents'):
|
||||
|
||||
noise = noise * noise_multiplier
|
||||
|
||||
|
||||
@@ -1763,6 +1763,15 @@ class StableDiffusion:
|
||||
)
|
||||
noise = apply_noise_offset(noise, noise_offset)
|
||||
return noise
|
||||
|
||||
def get_latent_noise_from_latents(self, latents: torch.Tensor, noise_offset=0.0):
    """Draw random noise shaped like ``latents`` and apply the noise offset.

    Args:
        latents: Reference tensor; the noise is sampled with
            ``torch.randn_like`` so it matches this tensor.
        noise_offset: Value forwarded unchanged to ``apply_noise_offset``.

    Returns:
        The result of ``apply_noise_offset`` on the freshly sampled noise.
    """
    # Sample first, then hand the sample to the shared offset helper.
    sampled = torch.randn_like(latents)
    return apply_noise_offset(sampled, noise_offset)
|
||||
|
||||
def get_time_ids_from_latents(self, latents: torch.Tensor, requires_aesthetic_score=False):
|
||||
VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1)
|
||||
|
||||
Reference in New Issue
Block a user