mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-01 11:41:35 +00:00
Numerous fixes for time sampling. Still not perfect
This commit is contained in:
@@ -666,7 +666,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
unconditional_imgs = batch.unconditional_tensor
|
||||
unconditional_imgs = unconditional_imgs.to(self.device_torch, dtype=dtype)
|
||||
unconditional_latents = self.sd.encode_images(unconditional_imgs)
|
||||
batch.unconditional_latents = unconditional_latents
|
||||
batch.unconditional_latents = unconditional_latents * self.train_config.latent_multiplier
|
||||
|
||||
unaugmented_latents = None
|
||||
if self.train_config.loss_target == 'differential_noise':
|
||||
@@ -715,36 +715,41 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
orig_timesteps = torch.rand((batch_size,), device=latents.device)
|
||||
|
||||
if self.train_config.content_or_style == 'content':
|
||||
timesteps = orig_timesteps ** 3 * self.sd.noise_scheduler.config['num_train_timesteps']
|
||||
timestep_indices = orig_timesteps ** 3 * self.sd.noise_scheduler.config['num_train_timesteps']
|
||||
elif self.train_config.content_or_style == 'style':
|
||||
timesteps = (1 - orig_timesteps ** 3) * self.sd.noise_scheduler.config['num_train_timesteps']
|
||||
timestep_indices = (1 - orig_timesteps ** 3) * self.sd.noise_scheduler.config['num_train_timesteps']
|
||||
|
||||
timesteps = value_map(
|
||||
timesteps,
|
||||
timestep_indices = value_map(
|
||||
timestep_indices,
|
||||
0,
|
||||
self.sd.noise_scheduler.config['num_train_timesteps'] - 1,
|
||||
min_noise_steps,
|
||||
max_noise_steps
|
||||
max_noise_steps - 1
|
||||
)
|
||||
timesteps = timesteps.long().clamp(
|
||||
timestep_indices = timestep_indices.long().clamp(
|
||||
min_noise_steps + 1,
|
||||
max_noise_steps - 1
|
||||
)
|
||||
|
||||
elif self.train_config.content_or_style == 'balanced':
|
||||
if min_noise_steps == max_noise_steps:
|
||||
timesteps = torch.ones((batch_size,), device=self.device_torch) * min_noise_steps
|
||||
timestep_indices = torch.ones((batch_size,), device=self.device_torch) * min_noise_steps
|
||||
else:
|
||||
timesteps = torch.randint(
|
||||
min_noise_steps,
|
||||
max_noise_steps,
|
||||
# todo, some schedulers use indices, others use timesteps. Not sure what to do here
|
||||
timestep_indices = torch.randint(
|
||||
min_noise_steps + 1,
|
||||
max_noise_steps - 1,
|
||||
(batch_size,),
|
||||
device=self.device_torch
|
||||
)
|
||||
timesteps = timesteps.long()
|
||||
timestep_indices = timestep_indices.long()
|
||||
else:
|
||||
raise ValueError(f"Unknown content_or_style {self.train_config.content_or_style}")
|
||||
|
||||
# convert the timestep_indices to a timestep
|
||||
timesteps = [self.sd.noise_scheduler.timesteps[x.item()] for x in timestep_indices]
|
||||
timesteps = torch.stack(timesteps, dim=0)
|
||||
|
||||
# get noise
|
||||
noise = self.sd.get_latent_noise(
|
||||
height=latents.shape[2],
|
||||
@@ -765,9 +770,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
|
||||
noise = noise * noise_multiplier
|
||||
|
||||
img_multiplier = self.train_config.img_multiplier
|
||||
latents = latents * self.train_config.latent_multiplier
|
||||
|
||||
latents = latents * img_multiplier
|
||||
if batch.unconditional_latents is not None:
|
||||
batch.unconditional_latents = batch.unconditional_latents * self.train_config.latent_multiplier
|
||||
|
||||
noisy_latents = self.sd.noise_scheduler.add_noise(latents, noise, timesteps)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user