Various features and fixes. Too much brain fog to do a proper description

This commit is contained in:
Jaret Burkett
2024-07-18 07:34:14 -06:00
parent 58dffd43a8
commit 11e426fdf1
6 changed files with 119 additions and 25 deletions

View File

@@ -356,7 +356,7 @@ class SDTrainer(BaseSDTrainProcess):
# we have to encode images into latents for now
# we also denoise as the unaugmented tensor is not a noisy diffirental
with torch.no_grad():
unaugmented_latents = self.sd.encode_images(batch.unaugmented_tensor)
unaugmented_latents = self.sd.encode_images(batch.unaugmented_tensor).to(self.device_torch, dtype=dtype)
unaugmented_latents = unaugmented_latents * self.train_config.latent_multiplier
target = unaugmented_latents.detach()
@@ -907,6 +907,17 @@ class SDTrainer(BaseSDTrainProcess):
self.timer.start('preprocess_batch')
batch = self.preprocess_batch(batch)
dtype = get_torch_dtype(self.train_config.dtype)
# sanity check
if self.sd.vae.dtype != self.sd.vae_torch_dtype:
self.sd.vae = self.sd.vae.to(self.sd.vae_torch_dtype)
if isinstance(self.sd.text_encoder, list):
for encoder in self.sd.text_encoder:
if encoder.dtype != self.sd.te_torch_dtype:
encoder.to(self.sd.te_torch_dtype)
else:
if self.sd.text_encoder.dtype != self.sd.te_torch_dtype:
self.sd.text_encoder.to(self.sd.te_torch_dtype)
noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)
if self.train_config.do_cfg or self.train_config.do_random_cfg:
# pick random negative prompts