mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 19:21:39 +00:00
Moved SD batch processing to a shared method and added it for use in slider training. Still testing if it affects quality over sampling
This commit is contained in:
@@ -36,67 +36,8 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
self.sd.text_encoder.train()
|
||||
|
||||
def hook_train_loop(self, batch):
|
||||
with torch.no_grad():
|
||||
imgs, prompts, dataset_config = batch
|
||||
|
||||
# convert the 0 or 1 for is reg to a bool list
|
||||
is_reg_list = dataset_config.get('is_reg', [0 for _ in range(imgs.shape[0])])
|
||||
if isinstance(is_reg_list, torch.Tensor):
|
||||
is_reg_list = is_reg_list.numpy().tolist()
|
||||
is_reg_list = [bool(x) for x in is_reg_list]
|
||||
|
||||
conditioned_prompts = []
|
||||
|
||||
for prompt, is_reg in zip(prompts, is_reg_list):
|
||||
|
||||
# make sure the embedding is in the prompts
|
||||
if self.embedding is not None:
|
||||
prompt = self.embedding.inject_embedding_to_prompt(
|
||||
prompt,
|
||||
expand_token=True,
|
||||
add_if_not_present=True,
|
||||
)
|
||||
|
||||
# make sure trigger is in the prompts if not a regularization run
|
||||
if self.trigger_word is not None and not is_reg:
|
||||
prompt = self.sd.inject_trigger_into_prompt(
|
||||
prompt,
|
||||
add_if_not_present=True,
|
||||
)
|
||||
conditioned_prompts.append(prompt)
|
||||
|
||||
batch_size = imgs.shape[0]
|
||||
|
||||
dtype = get_torch_dtype(self.train_config.dtype)
|
||||
imgs = imgs.to(self.device_torch, dtype=dtype)
|
||||
latents = self.sd.encode_images(imgs)
|
||||
|
||||
noise_scheduler = self.sd.noise_scheduler
|
||||
optimizer = self.optimizer
|
||||
lr_scheduler = self.lr_scheduler
|
||||
|
||||
self.sd.noise_scheduler.set_timesteps(
|
||||
self.train_config.max_denoising_steps, device=self.device_torch
|
||||
)
|
||||
|
||||
timesteps = torch.randint(0, self.train_config.max_denoising_steps, (batch_size,), device=self.device_torch)
|
||||
timesteps = timesteps.long()
|
||||
|
||||
# get noise
|
||||
noise = self.sd.get_latent_noise(
|
||||
pixel_height=imgs.shape[2],
|
||||
pixel_width=imgs.shape[3],
|
||||
batch_size=batch_size,
|
||||
noise_offset=self.train_config.noise_offset
|
||||
).to(self.device_torch, dtype=dtype)
|
||||
|
||||
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||
|
||||
# remove grads for these
|
||||
noisy_latents.requires_grad = False
|
||||
noise.requires_grad = False
|
||||
|
||||
flush()
|
||||
dtype = get_torch_dtype(self.train_config.dtype)
|
||||
noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)
|
||||
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
@@ -135,7 +76,7 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
|
||||
if self.sd.prediction_type == 'v_prediction':
|
||||
# v-parameterization training
|
||||
target = noise_scheduler.get_velocity(noisy_latents, noise, timesteps)
|
||||
target = self.sd.noise_scheduler.get_velocity(noisy_latents, noise, timesteps)
|
||||
else:
|
||||
target = noise
|
||||
|
||||
@@ -144,7 +85,7 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
|
||||
if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
|
||||
# add min_snr_gamma
|
||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, self.train_config.min_snr_gamma)
|
||||
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
|
||||
|
||||
loss = loss.mean()
|
||||
|
||||
@@ -153,9 +94,9 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
flush()
|
||||
|
||||
# apply gradients
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
lr_scheduler.step()
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
self.lr_scheduler.step()
|
||||
|
||||
if self.embedding is not None:
|
||||
# Let's make sure we don't update any embedding weights besides the newly added token
|
||||
|
||||
Reference in New Issue
Block a user