Added some experimental training techniques. Ignore for now. Still in testing.

Jaret Burkett
2025-05-21 02:19:54 -06:00
parent 01101be196
commit e5181d23cd
6 changed files with 240 additions and 43 deletions


@@ -931,16 +931,16 @@ class BaseSDTrainProcess(BaseTrainProcess):
     noise_offset=self.train_config.noise_offset,
 ).to(self.device_torch, dtype=dtype)
-if self.train_config.random_noise_shift > 0.0:
-    # get random noise -1 to 1
-    noise_shift = torch.rand((noise.shape[0], noise.shape[1], 1, 1), device=noise.device,
-                             dtype=noise.dtype) * 2 - 1
+# if self.train_config.random_noise_shift > 0.0:
+#     # get random noise -1 to 1
+#     noise_shift = torch.rand((noise.shape[0], noise.shape[1], 1, 1), device=noise.device,
+#                              dtype=noise.dtype) * 2 - 1
-    # multiply by shift amount
-    noise_shift *= self.train_config.random_noise_shift
+#     # multiply by shift amount
+#     noise_shift *= self.train_config.random_noise_shift
-    # add to noise
-    noise += noise_shift
+#     # add to noise
+#     noise += noise_shift
 if self.train_config.blended_blur_noise:
     noise = get_blended_blur_noise(
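For reference, the block commented out here applied a per-batch-item, per-channel uniform shift to the initial noise (it is replaced further down by a Gaussian, shape-aware version). A minimal standalone sketch of the disabled technique; the function name is hypothetical, and the latent shapes in the usage line are assumptions for illustration:

```python
import torch

def apply_uniform_noise_shift(noise: torch.Tensor, random_noise_shift: float) -> torch.Tensor:
    # draw one uniform value in [-1, 1] per batch item and per channel
    noise_shift = torch.rand(
        (noise.shape[0], noise.shape[1], 1, 1),
        device=noise.device, dtype=noise.dtype
    ) * 2 - 1
    # scale by the configured shift amount and offset the whole noise map
    return noise + noise_shift * random_noise_shift

# usage: perturb a batch of 4-channel latent-shaped noise
noise = torch.randn(2, 4, 64, 64)
shifted = apply_uniform_noise_shift(noise, random_noise_shift=0.1)
```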
@@ -1011,6 +1011,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
 dtype = get_torch_dtype(self.train_config.dtype)
 imgs = None
 is_reg = any(batch.get_is_reg_list())
+cfm_batch = None
 if batch.tensor is not None:
     imgs = batch.tensor
     imgs = imgs.to(self.device_torch, dtype=dtype)
@@ -1118,8 +1119,13 @@ class BaseSDTrainProcess(BaseTrainProcess):
 if timestep_type is None:
     timestep_type = self.train_config.timestep_type
+if self.train_config.timestep_type == 'next_sample':
+    # simulate a sample
+    num_train_timesteps = self.train_config.next_sample_timesteps
+    timestep_type = 'shift'
 patch_size = 1
-if self.sd.is_flux:
+if self.sd.is_flux or 'flex' in self.sd.arch:
     # flux is a patch size of 1, but latents are divided by 2, so we need to double it
     patch_size = 2
 elif hasattr(self.sd.unet.config, 'patch_size'):
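The patch-size logic above reads as a small lookup: Flux (and, after this commit, Flex architectures) declare a transformer patch size of 1 but pack 2x2 latent pixels per token, so the effective patch size is doubled; other models fall back to whatever their transformer/UNet config declares. A hedged sketch of that selection; the `is_flux`, `arch`, and `unet.config` attributes mirror the fields used in the diff, and the final default is an assumption:

```python
def effective_patch_size(sd) -> int:
    # flux/flex report patch_size 1 in config, but their latents are
    # 2x2-packed, so the effective spatial patch size is 2
    if sd.is_flux or 'flex' in sd.arch:
        return 2
    # otherwise trust the transformer/unet config when it declares one
    if hasattr(sd.unet.config, 'patch_size'):
        return sd.unet.config.patch_size
    return 1  # assumption: plain conv UNets have no patching
```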
@@ -1142,7 +1148,15 @@ class BaseSDTrainProcess(BaseTrainProcess):
     content_or_style = self.train_config.content_or_style_reg
 # if self.train_config.timestep_sampling == 'style' or self.train_config.timestep_sampling == 'content':
-if content_or_style in ['style', 'content']:
+if self.train_config.timestep_type == 'next_sample':
+    timestep_indices = torch.randint(
+        0,
+        num_train_timesteps - 2,  # -1 for 0 idx, -1 so we can step
+        (batch_size,),
+        device=self.device_torch
+    )
+    timestep_indices = timestep_indices.long()
+elif content_or_style in ['style', 'content']:
     # this is from diffusers training code
     # Cubic sampling for favoring later or earlier timesteps
     # For more details about why cubic sampling is used for content / structure,
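The new `next_sample` branch draws integer timestep indices uniformly, capping at `num_train_timesteps - 2` so every sampled index still has a valid next step to advance into. A self-contained sketch of that sampling; the function name and `device` argument are assumptions, the bounds follow the diff:

```python
import torch

def sample_next_sample_indices(batch_size: int, num_train_timesteps: int,
                               device: torch.device) -> torch.Tensor:
    # upper bound is num_train_timesteps - 2: -1 for 0-based indexing,
    # -1 more so that index + 1 is still a valid timestep to step into
    return torch.randint(0, num_train_timesteps - 2, (batch_size,), device=device).long()

indices = sample_next_sample_indices(4, 1000, torch.device('cpu'))
assert int(indices.max()) <= 997  # always room to take one more step
```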
@@ -1169,7 +1183,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
         min_noise_steps + 1,
         max_noise_steps - 1
     )
 elif content_or_style == 'balanced':
     if min_noise_steps == max_noise_steps:
         timestep_indices = torch.ones((batch_size,), device=self.device_torch) * min_noise_steps
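The `balanced` branch guards the degenerate case where the configured range collapses: when `min_noise_steps == max_noise_steps`, every batch item gets that single index rather than calling `randint` on an empty range. A hedged sketch of the pattern; the general-case bounds are an assumption based on the surrounding code:

```python
import torch

def sample_balanced_indices(batch_size, min_noise_steps, max_noise_steps, device):
    if min_noise_steps == max_noise_steps:
        # degenerate range: randint(low, low, ...) would error,
        # so broadcast the single valid index to the whole batch
        return (torch.ones((batch_size,), device=device) * min_noise_steps).long()
    # assumption: uniform over the configured range otherwise
    return torch.randint(min_noise_steps, max_noise_steps, (batch_size,), device=device).long()

idx = sample_balanced_indices(4, 250, 250, torch.device('cpu'))
```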
@@ -1185,16 +1199,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
 else:
     raise ValueError(f"Unknown content_or_style {content_or_style}")
-# do flow matching
-# if self.sd.is_flow_matching:
-#     u = compute_density_for_timestep_sampling(
-#         weighting_scheme="logit_normal",  # ["sigma_sqrt", "logit_normal", "mode", "cosmap"]
-#         batch_size=batch_size,
-#         logit_mean=0.0,
-#         logit_std=1.0,
-#         mode_scale=1.29,
-#     )
-#     timestep_indices = (u * self.sd.noise_scheduler.config.num_train_timesteps).long()
 # convert the timestep_indices to a timestep
 timesteps = [self.sd.noise_scheduler.timesteps[x.item()] for x in timestep_indices]
 timesteps = torch.stack(timesteps, dim=0)
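Whichever branch produced `timestep_indices`, the tail of this hunk maps them through the scheduler's `timesteps` table rather than using the raw indices, since schedulers may order or space timesteps non-uniformly. A minimal sketch of that lookup, using diffusers' `DDPMScheduler` as a stand-in for `self.sd.noise_scheduler`:

```python
import torch
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000)
timestep_indices = torch.randint(0, 1000, (4,))
# index into the scheduler's timestep table, one entry per batch item
timesteps = torch.stack([scheduler.timesteps[i.item()] for i in timestep_indices], dim=0)
```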
@@ -1218,8 +1222,32 @@ class BaseSDTrainProcess(BaseTrainProcess):
     latents = unaugmented_latents
 noise_multiplier = self.train_config.noise_multiplier
+s = (noise.shape[0], noise.shape[1], 1, 1)
+if len(noise.shape) == 5:
+    # a 5d (video) tensor also varies per frame, not just per batch item and channel
+    s = (noise.shape[0], noise.shape[1], noise.shape[2], 1, 1)
+if self.train_config.random_noise_multiplier > 0.0:
+    # do it on a per batch item, per channel basis
+    noise_multiplier = 1 + torch.randn(
+        s,
+        device=noise.device,
+        dtype=noise.dtype
+    ) * self.train_config.random_noise_multiplier
 noise = noise * noise_multiplier
+if self.train_config.random_noise_shift > 0.0:
+    # gaussian shift, scaled by the configured amount
+    noise_shift = torch.randn(
+        s,
+        device=noise.device,
+        dtype=noise.dtype
+    ) * self.train_config.random_noise_shift
+    # add to noise
+    noise += noise_shift
 latent_multiplier = self.train_config.latent_multiplier
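Taken together, the re-enabled block scales and offsets the noise with Gaussian draws whose broadcast shape adapts to video (5D) latents: one value per batch item and channel for images, plus one per frame for video. A standalone sketch of the combined perturbation; the function name is hypothetical, the config values and shapes follow the diff:

```python
import torch

def perturb_noise(noise, random_noise_multiplier=0.0, random_noise_shift=0.0):
    # broadcast shape: per batch item and channel; video latents (5D) also vary per frame
    s = (noise.shape[0], noise.shape[1], 1, 1)
    if noise.dim() == 5:
        s = (noise.shape[0], noise.shape[1], noise.shape[2], 1, 1)
    if random_noise_multiplier > 0.0:
        # gaussian jitter around 1.0 rescales each channel's noise
        mult = 1 + torch.randn(s, device=noise.device, dtype=noise.dtype) * random_noise_multiplier
        noise = noise * mult
    if random_noise_shift > 0.0:
        # gaussian offset shifts each channel's noise mean
        noise = noise + torch.randn(s, device=noise.device, dtype=noise.dtype) * random_noise_shift
    return noise

video_noise = torch.randn(1, 4, 8, 32, 32)  # (batch, channels, frames, h, w)
perturbed = perturb_noise(video_noise, random_noise_multiplier=0.02, random_noise_shift=0.01)
```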