Performance optimizations for pre processing the batch

This commit is contained in:
Jaret Burkett
2025-06-17 07:37:41 -06:00
parent 11f2eee53a
commit 1cc663a664
4 changed files with 120 additions and 99 deletions

View File

@@ -417,6 +417,16 @@ Everything else should work the same including layer targeting.
## Updates
### June 17, 2024
- Performance optimizations for batch preparation
### June 16, 2024
- Hide control images in the UI when viewing datasets
- WIP on mean flow loss
### June 12, 2024
- Fixed issue that resulted in blank captions in the dataloader
### June 10, 2024
- Decided to keep track up updates in the readme
- Added support for SDXL in the UI

View File

@@ -1009,6 +1009,7 @@ class SDTrainer(BaseSDTrainProcess):
return loss
def train_single_accumulation(self, batch: DataLoaderBatchDTO):
with torch.no_grad():
self.timer.start('preprocess_batch')
if isinstance(self.adapter, CustomAdapter):
batch = self.adapter.edit_batch_raw(batch)
@@ -1078,7 +1079,6 @@ class SDTrainer(BaseSDTrainProcess):
self.timer.stop('preprocess_batch')
is_reg = False
with torch.no_grad():
loss_multiplier = torch.ones((noisy_latents.shape[0], 1, 1, 1), device=self.device_torch, dtype=dtype)
for idx, file_item in enumerate(batch.file_items):
if file_item.is_reg:

View File

@@ -921,7 +921,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
noise = self.get_consistent_noise(latents, batch, dtype=dtype)
else:
if hasattr(self.sd, 'get_latent_noise_from_latents'):
noise = self.sd.get_latent_noise_from_latents(latents).to(self.device_torch, dtype=dtype)
noise = self.sd.get_latent_noise_from_latents(
latents,
noise_offset=self.train_config.noise_offset
).to(self.device_torch, dtype=dtype)
else:
# get noise
noise = self.sd.get_latent_noise(
@@ -932,17 +935,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
noise_offset=self.train_config.noise_offset,
).to(self.device_torch, dtype=dtype)
# if self.train_config.random_noise_shift > 0.0:
# # get random noise -1 to 1
# noise_shift = torch.rand((noise.shape[0], noise.shape[1], 1, 1), device=noise.device,
# dtype=noise.dtype) * 2 - 1
# # multiply by shift amount
# noise_shift *= self.train_config.random_noise_shift
# # add to noise
# noise += noise_shift
if self.train_config.blended_blur_noise:
noise = get_blended_blur_noise(
latents, noise, timestep
@@ -1085,6 +1077,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
# we determine noise from the differential of the latents
unaugmented_latents = self.sd.encode_images(batch.unaugmented_tensor)
with self.timer('prepare_scheduler'):
batch_size = len(batch.file_items)
min_noise_steps = self.train_config.min_denoising_steps
max_noise_steps = self.train_config.max_denoising_steps
@@ -1097,7 +1091,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
min_noise_steps = round(self.train_config.max_denoising_steps * self.model_config.refiner_start_at)
do_double = False
with self.timer('prepare_noise'):
num_train_timesteps = self.train_config.num_train_timesteps
if self.train_config.noise_scheduler in ['custom_lcm']:
@@ -1144,6 +1137,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.sd.noise_scheduler.set_timesteps(
num_train_timesteps, device=self.device_torch
)
with self.timer('prepare_timesteps_indices'):
content_or_style = self.train_config.content_or_style
if is_reg:
@@ -1193,20 +1187,26 @@ class BaseSDTrainProcess(BaseTrainProcess):
timestep_indices = torch.ones((batch_size,), device=self.device_torch) * min_noise_steps
else:
# todo, some schedulers use indices, otheres use timesteps. Not sure what to do here
min_idx = min_noise_steps + 1
max_idx = max_noise_steps - 1
if self.train_config.noise_scheduler == 'flowmatch':
# flowmatch uses indices, so we need to use indices
min_idx = 0
max_idx = max_noise_steps - 1
timestep_indices = torch.randint(
min_noise_steps + 1,
max_noise_steps - 1,
min_idx,
max_idx,
(batch_size,),
device=self.device_torch
)
timestep_indices = timestep_indices.long()
else:
raise ValueError(f"Unknown content_or_style {content_or_style}")
with self.timer('convert_timestep_indices_to_timesteps'):
# convert the timestep_indices to a timestep
timesteps = [self.sd.noise_scheduler.timesteps[x.item()] for x in timestep_indices]
timesteps = torch.stack(timesteps, dim=0)
timesteps = self.sd.noise_scheduler.timesteps[timestep_indices.long()]
with self.timer('prepare_noise'):
# get noise
noise = self.get_noise(latents, batch_size, dtype=dtype, batch=batch, timestep=timesteps)
@@ -1241,6 +1241,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
dtype=noise.dtype
) * self.train_config.random_noise_multiplier
with self.timer('make_noisy_latents'):
noise = noise * noise_multiplier
if self.train_config.random_noise_shift > 0.0:

View File

@@ -1764,6 +1764,15 @@ class StableDiffusion:
noise = apply_noise_offset(noise, noise_offset)
return noise
def get_latent_noise_from_latents(
self,
latents: torch.Tensor,
noise_offset=0.0
):
noise = torch.randn_like(latents)
noise = apply_noise_offset(noise, noise_offset)
return noise
def get_time_ids_from_latents(self, latents: torch.Tensor, requires_aesthetic_score=False):
VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1)
if self.is_xl: