Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-01-26 16:39:47 +00:00
Added some experimental training techniques. Ignore for now. Still in testing.
@@ -931,16 +931,16 @@ class BaseSDTrainProcess(BaseTrainProcess):
             noise_offset=self.train_config.noise_offset,
         ).to(self.device_torch, dtype=dtype)
 
-        if self.train_config.random_noise_shift > 0.0:
-            # get random noise -1 to 1
-            noise_shift = torch.rand((noise.shape[0], noise.shape[1], 1, 1), device=noise.device,
-                                     dtype=noise.dtype) * 2 - 1
+        # if self.train_config.random_noise_shift > 0.0:
+        #     # get random noise -1 to 1
+        #     noise_shift = torch.rand((noise.shape[0], noise.shape[1], 1, 1), device=noise.device,
+        #                              dtype=noise.dtype) * 2 - 1
 
-            # multiply by shift amount
-            noise_shift *= self.train_config.random_noise_shift
+        #     # multiply by shift amount
+        #     noise_shift *= self.train_config.random_noise_shift
 
-            # add to noise
-            noise += noise_shift
+        #     # add to noise
+        #     noise += noise_shift
 
         if self.train_config.blended_blur_noise:
             noise = get_blended_blur_noise(
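For context, the block commented out above drew one uniform offset in [-1, 1] per batch item and channel, scaled it by random_noise_shift, and added it to the sampled noise. A minimal standalone sketch of that behavior, assuming a 4D (batch, channel, height, width) noise tensor; the helper name and shapes are illustrative, not part of the commit:

import torch

def shift_noise_uniform(noise: torch.Tensor, random_noise_shift: float) -> torch.Tensor:
    # one uniform offset in [-1, 1] per batch item, per channel
    noise_shift = torch.rand(
        (noise.shape[0], noise.shape[1], 1, 1),
        device=noise.device, dtype=noise.dtype
    ) * 2 - 1
    # scale by the configured shift amount, then add to the noise
    return noise + noise_shift * random_noise_shift

noise = torch.randn(2, 4, 64, 64)
print(shift_noise_uniform(noise, random_noise_shift=0.1).shape)  # torch.Size([2, 4, 64, 64])

A Gaussian replacement for this shift is added near line 1218 in the final hunk below.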
@@ -1011,6 +1011,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
         dtype = get_torch_dtype(self.train_config.dtype)
         imgs = None
         is_reg = any(batch.get_is_reg_list())
+        cfm_batch = None
         if batch.tensor is not None:
             imgs = batch.tensor
             imgs = imgs.to(self.device_torch, dtype=dtype)
@@ -1118,8 +1119,13 @@ class BaseSDTrainProcess(BaseTrainProcess):
         if timestep_type is None:
             timestep_type = self.train_config.timestep_type
 
+        if self.train_config.timestep_type == 'next_sample':
+            # simulate a sample
+            num_train_timesteps = self.train_config.next_sample_timesteps
+            timestep_type = 'shift'
+
         patch_size = 1
-        if self.sd.is_flux:
+        if self.sd.is_flux or 'flex' in self.sd.arch:
             # flux is a patch size of 1, but latents are divided by 2, so we need to double it
             patch_size = 2
         elif hasattr(self.sd.unet.config, 'patch_size'):
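A hedged sketch of the patch-size resolution above as a free function; the sd attributes mirror the diff, but the body of the final elif is an assumption, since the hunk truncates after that line:

def resolve_patch_size(sd) -> int:
    patch_size = 1
    if sd.is_flux or 'flex' in sd.arch:
        # flux/flex transformer patches are 1x1, but the latents are packed
        # 2x2 first, so the effective spatial patch size is 2
        patch_size = 2
    elif hasattr(sd.unet.config, 'patch_size'):
        patch_size = sd.unet.config.patch_size  # assumed continuation of the hunk
    return patch_size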
@@ -1142,7 +1148,15 @@ class BaseSDTrainProcess(BaseTrainProcess):
         content_or_style = self.train_config.content_or_style_reg
 
         # if self.train_config.timestep_sampling == 'style' or self.train_config.timestep_sampling == 'content':
-        if content_or_style in ['style', 'content']:
+        if self.train_config.timestep_type == 'next_sample':
+            timestep_indices = torch.randint(
+                0,
+                num_train_timesteps - 2,  # -1 for 0 idx, -1 so we can step
+                (batch_size,),
+                device=self.device_torch
+            )
+            timestep_indices = timestep_indices.long()
+        elif content_or_style in ['style', 'content']:
             # this is from diffusers training code
             # Cubic sampling for favoring later or earlier timesteps
             # For more details about why cubic sampling is used for content / structure,
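A small worked example of the index range used by the new 'next_sample' branch, with an assumed timestep count: torch.randint's upper bound is exclusive, so the draw stays low enough that index + 1 is always a valid timestep to step to.

import torch

num_train_timesteps = 28  # assumed value of train_config.next_sample_timesteps
batch_size = 4
# indices land in 0 .. num_train_timesteps - 3 (upper bound exclusive),
# leaving room to advance one scheduler step during the simulated sample
timestep_indices = torch.randint(0, num_train_timesteps - 2, (batch_size,)).long()
assert int(timestep_indices.max()) <= num_train_timesteps - 3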
@@ -1169,7 +1183,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 min_noise_steps + 1,
                 max_noise_steps - 1
             )
 
 
         elif content_or_style == 'balanced':
             if min_noise_steps == max_noise_steps:
                 timestep_indices = torch.ones((batch_size,), device=self.device_torch) * min_noise_steps
@@ -1185,16 +1199,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
         else:
             raise ValueError(f"Unknown content_or_style {content_or_style}")
 
-        # do flow matching
-        # if self.sd.is_flow_matching:
-        #     u = compute_density_for_timestep_sampling(
-        #         weighting_scheme="logit_normal",  # ["sigma_sqrt", "logit_normal", "mode", "cosmap"]
-        #         batch_size=batch_size,
-        #         logit_mean=0.0,
-        #         logit_std=1.0,
-        #         mode_scale=1.29,
-        #     )
-        #     timestep_indices = (u * self.sd.noise_scheduler.config.num_train_timesteps).long()
         # convert the timestep_indices to a timestep
         timesteps = [self.sd.noise_scheduler.timesteps[x.item()] for x in timestep_indices]
         timesteps = torch.stack(timesteps, dim=0)
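The deleted comment block pointed at diffusers' logit-normal density sampling for flow-matching timesteps. For reference, a sketch of that path, mirroring the removed comment; the exact import location and defaults depend on the installed diffusers version:

from diffusers.training_utils import compute_density_for_timestep_sampling

num_train_timesteps = 1000  # assumed noise_scheduler.config.num_train_timesteps
u = compute_density_for_timestep_sampling(
    weighting_scheme="logit_normal",  # ["sigma_sqrt", "logit_normal", "mode", "cosmap"]
    batch_size=4,
    logit_mean=0.0,
    logit_std=1.0,
    mode_scale=1.29,
)
timestep_indices = (u * num_train_timesteps).long()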
@@ -1218,8 +1222,32 @@ class BaseSDTrainProcess(BaseTrainProcess):
             latents = unaugmented_latents
 
         noise_multiplier = self.train_config.noise_multiplier
 
+        s = (noise.shape[0], noise.shape[1], 1, 1)
+        if len(noise.shape) == 5:
+            # if we have a 5d tensor, then we need to do it on a per batch item, per channel basis, per frame
+            s = (noise.shape[0], noise.shape[1], noise.shape[2], 1, 1)
+
+        if self.train_config.random_noise_multiplier > 0.0:
+
+            # do it on a per batch item, per channel basis
+            noise_multiplier = 1 + torch.randn(
+                s,
+                device=noise.device,
+                dtype=noise.dtype
+            ) * self.train_config.random_noise_multiplier
 
         noise = noise * noise_multiplier
 
+        if self.train_config.random_noise_shift > 0.0:
+            # get random noise -1 to 1
+            noise_shift = torch.randn(
+                s,
+                device=noise.device,
+                dtype=noise.dtype
+            ) * self.train_config.random_noise_shift
+            # add to noise
+            noise += noise_shift
 
         latent_multiplier = self.train_config.latent_multiplier
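Taken together, the new block applies a Gaussian per-batch-item, per-channel multiplier and shift to the noise, with an extra per-frame axis when the latents are 5D (video). A standalone sketch under those assumptions; the helper name, the B, C, F, H, W layout, and the parameter values are illustrative:

import torch

def augment_noise(noise, random_noise_multiplier=0.0, random_noise_shift=0.0):
    # broadcast shape: one value per batch item and channel (and frame for 5D)
    s = (noise.shape[0], noise.shape[1], 1, 1)
    if len(noise.shape) == 5:
        s = (noise.shape[0], noise.shape[1], noise.shape[2], 1, 1)
    if random_noise_multiplier > 0.0:
        mult = 1 + torch.randn(s, device=noise.device, dtype=noise.dtype) * random_noise_multiplier
        noise = noise * mult
    if random_noise_shift > 0.0:
        shift = torch.randn(s, device=noise.device, dtype=noise.dtype) * random_noise_shift
        noise = noise + shift
    return noise

noise = torch.randn(2, 16, 4, 32, 32)  # assumed 5D video latent: B, C, F, H, W
print(augment_noise(noise, random_noise_multiplier=0.02, random_noise_shift=0.02).shape)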