diff --git a/jobs/process/TrainVAEProcess.py b/jobs/process/TrainVAEProcess.py
index 19c72472..1492b19f 100644
--- a/jobs/process/TrainVAEProcess.py
+++ b/jobs/process/TrainVAEProcess.py
@@ -23,7 +23,7 @@ from toolkit.metadata import get_meta_for_safetensors
 from toolkit.optimizer import get_optimizer
 from toolkit.style import get_style_model_and_losses
 from toolkit.train_tools import get_torch_dtype
-from diffusers import AutoencoderKL
+from diffusers import AutoencoderKL, AutoencoderTiny
 from tqdm import tqdm
 import math
 import torchvision.utils
@@ -94,6 +94,9 @@ class TrainVAEProcess(BaseTrainProcess):
         self.dropout = self.get_conf('dropout', 0.0, as_type=float)
         self.train_encoder = self.get_conf('train_encoder', False, as_type=bool)
         self.random_scaling = self.get_conf('random_scaling', False, as_type=bool)
+        self.vae_type = self.get_conf('vae_type', 'AutoencoderKL', as_type=str)  # AutoencoderKL or AutoencoderTiny
+
+        self.VaeClass = AutoencoderKL if self.vae_type == 'AutoencoderKL' else AutoencoderTiny
 
         if not self.train_encoder:
             # remove losses that only target encoder
@@ -407,7 +410,8 @@
                 input_img = img
                 img = IMAGE_TRANSFORMS(img).unsqueeze(0).to(self.device, dtype=self.torch_dtype)
                 img = img
-                latent = self.vae.encode(img).latent_dist.sample()
+                # latent = self.vae.encode(img).latent_dist.sample()
+                latent = self.vae.encode(img, return_dict=False)[0]
                 latent_img = latent.clone()
                 bs, ch, h, w = latent_img.shape
 
@@ -492,9 +496,9 @@
         self.print(f" - Loading VAE: {path_to_load}")
         if self.vae is None:
             if path_to_load is not None:
-                self.vae = AutoencoderKL.from_pretrained(path_to_load)
+                self.vae = self.VaeClass.from_pretrained(path_to_load)
             elif self.vae_config is not None:
-                self.vae = AutoencoderKL(**self.vae_config)
+                self.vae = self.VaeClass(**self.vae_config)
             else:
                 raise ValueError('vae_path or ae_config must be specified')
 
@@ -511,7 +515,7 @@
         if self.target_latent_vae_path is not None:
             self.print(f"Loading target latent VAE from {self.target_latent_vae_path}")
             self.target_latent_vae = AutoencoderKL.from_pretrained(self.target_latent_vae_path)
-            self.target_latent_vae.to(self.device, dtype=self.torch_dtype)
+            self.target_latent_vae.to(self.device, dtype=torch.float32)
             self.target_latent_vae.eval()
             self.target_vae_scale_factor = 2 ** (len(self.target_latent_vae.config['block_out_channels']) - 1)
         else:
@@ -664,20 +668,26 @@
                 target_input_scale = self.target_vae_scale_factor / self.vae_scale_factor
                 target_input_size = (int(batch.shape[2] * target_input_scale), int(batch.shape[3] * target_input_scale))
                 # resize to target input size
-                target_input_batch = Resize(target_input_size)(batch)
+                target_input_batch = Resize(target_input_size)(batch).to(self.device, dtype=torch.float32)
                 target_latent = self.target_latent_vae.encode(target_input_batch).latent_dist.sample().detach()
+                target_latent = target_latent.to(self.device, dtype=self.torch_dtype)
 
             # forward pass
             # grad only if eq_vae
             with torch.set_grad_enabled(self.train_encoder):
-                dgd = self.vae.encode(batch).latent_dist
-                mu, logvar = dgd.mean, dgd.logvar
-                latents = dgd.sample()
+                if self.vae_type == 'AutoencoderTiny':
+                    # AutoencoderTiny cannot do latent distribution sampling
+                    latents = self.vae.encode(batch, return_dict=False)[0]
+                    mu, logvar = None, None
+                else:
+                    dgd = self.vae.encode(batch).latent_dist
+                    mu, logvar = dgd.mean, dgd.logvar
+                    latents = dgd.sample()
 
                 if target_latent is not None:
                     # forward_latents = target_latent.detach()
-                    lat_mse_loss = self.get_mse_loss(target_latent, latents)
+                    lat_mse_loss = self.get_mse_loss(target_latent.float(), latents.float())
                     latents = target_latent.detach()
                     forward_latents = target_latent.detach()
 
diff --git a/toolkit/data_loader.py b/toolkit/data_loader.py
index 47f2bfef..54f5a98c 100644
--- a/toolkit/data_loader.py
+++ b/toolkit/data_loader.py
@@ -16,6 +16,7 @@ from torch.utils.data import Dataset, DataLoader, ConcatDataset
 from tqdm import tqdm
 import albumentations as A
 
+from toolkit import image_utils
 from toolkit.buckets import get_bucket_for_image_size, BucketResolution
 from toolkit.config_modules import DatasetConfig, preprocess_dataset_raw_config
 from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin, LatentCachingMixin, Augments, CLIPCachingMixin, ControlCachingMixin
@@ -100,8 +101,13 @@ class ImageDataset(Dataset, CaptionMixin):
         new_file_list = []
         bad_count = 0
         for file in tqdm(self.file_list):
-            img = Image.open(file)
-            if int(min(img.size) * self.scale) >= self.resolution:
+            try:
+                w, h = image_utils.get_image_size(file)
+            except image_utils.UnknownImageFormat:
+                img = exif_transpose(Image.open(file))
+                w, h = img.size
+            # img = Image.open(file)
+            if int(min([w, h]) * self.scale) >= self.resolution:
                 new_file_list.append(file)
             else:
                 bad_count += 1
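
Note on the encode branch above: in diffusers, AutoencoderKL and AutoencoderTiny return different things from encode(), which is why the AutoencoderTiny path cannot sample a latent distribution and sets mu/logvar to None. A minimal standalone sketch of that difference, assuming only the public diffusers API (the pretrained model ids are illustrative placeholders, not defaults used by this trainer):

    import torch
    from diffusers import AutoencoderKL, AutoencoderTiny

    x = torch.randn(1, 3, 256, 256)

    # AutoencoderKL.encode returns a posterior distribution; mu/logvar come from it
    kl_vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
    dist = kl_vae.encode(x).latent_dist
    mu, logvar = dist.mean, dist.logvar
    kl_latents = dist.sample()

    # AutoencoderTiny.encode returns the latents tensor directly; there is no
    # latent_dist, and return_dict=False yields a (latents,) tuple
    tiny_vae = AutoencoderTiny.from_pretrained("madebyollin/taesd")
    tiny_latents = tiny_vae.encode(x, return_dict=False)[0]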
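The dtype edits in the target-latent path follow a common mixed-precision pattern: keep the frozen reference VAE in float32, cast its output back to the training dtype for the rest of the step, and compute the latent-matching MSE in float32. A sketch of that pattern under stated assumptions (device, train_dtype, batch, and student_latents are placeholders for this example, not identifiers from the diff):

    import torch
    import torch.nn.functional as F
    from torchvision.transforms import Resize
    from diffusers import AutoencoderKL

    device = "cuda" if torch.cuda.is_available() else "cpu"
    train_dtype = torch.float16

    # frozen reference VAE stays in fp32 to avoid half-precision drift
    target_vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
    target_vae.to(device, dtype=torch.float32).eval()

    batch = torch.randn(2, 3, 256, 256, device=device, dtype=train_dtype)
    student_latents = torch.randn(2, 4, 32, 32, device=device, dtype=train_dtype)

    with torch.no_grad():
        # resize/cast the input so the fp32 reference VAE never sees fp16 activations
        ref_in = Resize((256, 256))(batch).to(device, dtype=torch.float32)
        target_latent = target_vae.encode(ref_in).latent_dist.sample()
    target_latent = target_latent.to(device, dtype=train_dtype)

    # compare in fp32 so fp16 rounding does not dominate small residuals
    lat_mse_loss = F.mse_loss(student_latents.float(), target_latent.float())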