Moved SD batch processing to a shared method and added it for use in slider training. Still testing if it affects quality over sampling

This commit is contained in:
Jaret Burkett
2023-08-26 08:55:00 -06:00
parent aeaca13d69
commit 3367ab6b2c
6 changed files with 178 additions and 115 deletions

View File

@@ -283,6 +283,66 @@ class BaseSDTrainProcess(BaseTrainProcess):
else:
print("load_weights not implemented for non-network models")
def process_general_training_batch(self, batch):
with torch.no_grad():
imgs, prompts, dataset_config = batch
# convert the 0 or 1 for is reg to a bool list
is_reg_list = dataset_config.get('is_reg', [0 for _ in range(imgs.shape[0])])
if isinstance(is_reg_list, torch.Tensor):
is_reg_list = is_reg_list.numpy().tolist()
is_reg_list = [bool(x) for x in is_reg_list]
conditioned_prompts = []
for prompt, is_reg in zip(prompts, is_reg_list):
# make sure the embedding is in the prompts
if self.embedding is not None:
prompt = self.embedding.inject_embedding_to_prompt(
prompt,
expand_token=True,
add_if_not_present=True,
)
# make sure trigger is in the prompts if not a regularization run
if self.trigger_word is not None and not is_reg:
prompt = self.sd.inject_trigger_into_prompt(
prompt,
add_if_not_present=True,
)
conditioned_prompts.append(prompt)
batch_size = imgs.shape[0]
dtype = get_torch_dtype(self.train_config.dtype)
imgs = imgs.to(self.device_torch, dtype=dtype)
latents = self.sd.encode_images(imgs)
self.sd.noise_scheduler.set_timesteps(
self.train_config.max_denoising_steps, device=self.device_torch
)
timesteps = torch.randint(0, self.train_config.max_denoising_steps, (batch_size,), device=self.device_torch)
timesteps = timesteps.long()
# get noise
noise = self.sd.get_latent_noise(
pixel_height=imgs.shape[2],
pixel_width=imgs.shape[3],
batch_size=batch_size,
noise_offset=self.train_config.noise_offset
).to(self.device_torch, dtype=dtype)
noisy_latents = self.sd.noise_scheduler.add_noise(latents, noise, timesteps)
# remove grads for these
noisy_latents.requires_grad = False
noise.requires_grad = False
return noisy_latents, noise, timesteps, conditioned_prompts, imgs
def run(self):
# run base process run
BaseTrainProcess.run(self)