Allow finetuning tiny autoencoder in vae trainer

Author: Jaret Burkett
Date:   2025-07-16 07:13:30 -06:00
parent 1930c3edea
commit e5ed450dc7

2 changed files with 28 additions and 12 deletions

File 1 of 2

@@ -23,7 +23,7 @@ from toolkit.metadata import get_meta_for_safetensors
 from toolkit.optimizer import get_optimizer
 from toolkit.style import get_style_model_and_losses
 from toolkit.train_tools import get_torch_dtype
-from diffusers import AutoencoderKL
+from diffusers import AutoencoderKL, AutoencoderTiny
 from tqdm import tqdm
 import math
 import torchvision.utils
@@ -94,6 +94,9 @@ class TrainVAEProcess(BaseTrainProcess):
         self.dropout = self.get_conf('dropout', 0.0, as_type=float)
         self.train_encoder = self.get_conf('train_encoder', False, as_type=bool)
         self.random_scaling = self.get_conf('random_scaling', False, as_type=bool)
+        self.vae_type = self.get_conf('vae_type', 'AutoencoderKL', as_type=str)  # AutoencoderKL or AutoencoderTiny
+        self.VaeClass = AutoencoderKL if self.vae_type == 'AutoencoderKL' else AutoencoderTiny
         if not self.train_encoder:
             # remove losses that only target encoder
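Note the fallthrough in the class selection: any vae_type other than the exact string 'AutoencoderKL' resolves to AutoencoderTiny, including misspellings. A minimal sketch of the resolution in isolation:

    from diffusers import AutoencoderKL, AutoencoderTiny

    # Mirrors the selection above: only the exact string 'AutoencoderKL'
    # keeps the KL class; anything else falls through to AutoencoderTiny.
    vae_type = 'AutoencoderTiny'
    VaeClass = AutoencoderKL if vae_type == 'AutoencoderKL' else AutoencoderTiny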
@@ -407,7 +410,8 @@ class TrainVAEProcess(BaseTrainProcess):
             input_img = img
             img = IMAGE_TRANSFORMS(img).unsqueeze(0).to(self.device, dtype=self.torch_dtype)
             img = img
-            latent = self.vae.encode(img).latent_dist.sample()
+            # latent = self.vae.encode(img).latent_dist.sample()
+            latent = self.vae.encode(img, return_dict=False)[0]
             latent_img = latent.clone()
             bs, ch, h, w = latent_img.shape
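One caveat with the return_dict=False change: the two diffusers classes return different objects from encode(). AutoencoderKL.encode(img, return_dict=False)[0] is a DiagonalGaussianDistribution, while AutoencoderTiny returns the latent tensor directly, so the new line only yields a plain tensor on the tiny path. A minimal sketch of a helper that normalizes both cases (encode_to_latents is illustrative, not part of the commit):

    import torch
    from diffusers import AutoencoderKL, AutoencoderTiny

    def encode_to_latents(vae, img: torch.Tensor) -> torch.Tensor:
        out = vae.encode(img, return_dict=False)[0]
        if isinstance(vae, AutoencoderKL):
            # AutoencoderKL returns a DiagonalGaussianDistribution
            return out.sample()
        # AutoencoderTiny returns the latent tensor itself
        return out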
@@ -492,9 +496,9 @@ class TrainVAEProcess(BaseTrainProcess):
         self.print(f" - Loading VAE: {path_to_load}")
         if self.vae is None:
             if path_to_load is not None:
-                self.vae = AutoencoderKL.from_pretrained(path_to_load)
+                self.vae = self.VaeClass.from_pretrained(path_to_load)
             elif self.vae_config is not None:
-                self.vae = AutoencoderKL(**self.vae_config)
+                self.vae = self.VaeClass(**self.vae_config)
             else:
                 raise ValueError('vae_path or ae_config must be specified')
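Both load paths now go through the configured class. A sketch of what each path amounts to for the tiny case (the repo id and config kwargs are illustrative; madebyollin/taesd hosts commonly used TAESD weights):

    from diffusers import AutoencoderTiny

    # pretrained weights, mirroring self.VaeClass.from_pretrained(path_to_load)
    vae = AutoencoderTiny.from_pretrained('madebyollin/taesd')

    # fresh weights from a config dict, mirroring self.VaeClass(**self.vae_config)
    vae = AutoencoderTiny(latent_channels=4)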
@@ -511,7 +515,7 @@ class TrainVAEProcess(BaseTrainProcess):
         if self.target_latent_vae_path is not None:
             self.print(f"Loading target latent VAE from {self.target_latent_vae_path}")
             self.target_latent_vae = AutoencoderKL.from_pretrained(self.target_latent_vae_path)
-            self.target_latent_vae.to(self.device, dtype=self.torch_dtype)
+            self.target_latent_vae.to(self.device, dtype=torch.float32)
             self.target_latent_vae.eval()
             self.target_vae_scale_factor = 2 ** (len(self.target_latent_vae.config['block_out_channels']) - 1)
         else:
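Pinning the frozen target VAE to float32 keeps the regression targets numerically stable even when the trained model runs in half precision; the .float() casts in the latent MSE below follow the same logic. A sketch of the resulting precision split, assuming the trained VAE stays in self.torch_dtype (the repo id and the requires_grad_ call are illustrative, not from the diff):

    import torch
    from diffusers import AutoencoderKL

    # frozen reference encoder kept in full precision for precise latent targets
    target_vae = AutoencoderKL.from_pretrained('stabilityai/sd-vae-ft-mse')
    target_vae.to('cuda', dtype=torch.float32).eval()
    target_vae.requires_grad_(False)  # assumption: gradients are never needed here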
@@ -664,20 +668,26 @@ class TrainVAEProcess(BaseTrainProcess):
                 target_input_scale = self.target_vae_scale_factor / self.vae_scale_factor
                 target_input_size = (int(batch.shape[2] * target_input_scale), int(batch.shape[3] * target_input_scale))
                 # resize to target input size
-                target_input_batch = Resize(target_input_size)(batch)
+                target_input_batch = Resize(target_input_size)(batch).to(self.device, dtype=torch.float32)
                 target_latent = self.target_latent_vae.encode(target_input_batch).latent_dist.sample().detach()
+                target_latent = target_latent.to(self.device, dtype=self.torch_dtype)

             # forward pass
             # grad only if eq_vae
             with torch.set_grad_enabled(self.train_encoder):
-                dgd = self.vae.encode(batch).latent_dist
-                mu, logvar = dgd.mean, dgd.logvar
-                latents = dgd.sample()
+                if self.vae_type == 'AutoencoderTiny':
+                    # AutoencoderTiny cannot do latent distribution sampling
+                    latents = self.vae.encode(batch, return_dict=False)[0]
+                    mu, logvar = None, None
+                else:
+                    dgd = self.vae.encode(batch).latent_dist
+                    mu, logvar = dgd.mean, dgd.logvar
+                    latents = dgd.sample()

                 if target_latent is not None:
                     # forward_latents = target_latent.detach()
-                    lat_mse_loss = self.get_mse_loss(target_latent, latents)
+                    lat_mse_loss = self.get_mse_loss(target_latent.float(), latents.float())
                     latents = target_latent.detach()
                     forward_latents = target_latent.detach()
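Because the tiny path returns mu and logvar as None, any KL-divergence term computed downstream presumably has to be skipped for AutoencoderTiny. A hedged sketch of such a guard (kl_term is illustrative; the diff does not show how the trainer consumes mu and logvar):

    from typing import Optional
    import torch

    def kl_term(mu: Optional[torch.Tensor], logvar: Optional[torch.Tensor]) -> torch.Tensor:
        # standard diagonal-Gaussian KL against N(0, I); zero on the tiny path,
        # where encode() provides no distribution parameters
        if mu is None or logvar is None:
            return torch.zeros(())
        return -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())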

File 2 of 2

@@ -16,6 +16,7 @@ from torch.utils.data import Dataset, DataLoader, ConcatDataset
 from tqdm import tqdm
 import albumentations as A
+from toolkit import image_utils
 from toolkit.buckets import get_bucket_for_image_size, BucketResolution
 from toolkit.config_modules import DatasetConfig, preprocess_dataset_raw_config
 from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin, LatentCachingMixin, Augments, CLIPCachingMixin, ControlCachingMixin
@@ -100,8 +101,13 @@ class ImageDataset(Dataset, CaptionMixin):
         new_file_list = []
         bad_count = 0
         for file in tqdm(self.file_list):
-            img = Image.open(file)
-            if int(min(img.size) * self.scale) >= self.resolution:
+            try:
+                w, h = image_utils.get_image_size(file)
+            except image_utils.UnknownImageFormat:
+                img = exif_transpose(Image.open(file))
+                w, h = img.size
+            # img = Image.open(file)
+            if int(min([w, h]) * self.scale) >= self.resolution:
                 new_file_list.append(file)
             else:
                 bad_count += 1
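The intent of this change: filter files by the dimensions declared in the image header, without decoding pixels for every file, and fall back to a full PIL open (with EXIF orientation applied) only when image_utils cannot parse the format. A condensed sketch of the resulting check, assuming image_utils.get_image_size reads header metadata and raises image_utils.UnknownImageFormat on unsupported files:

    from PIL import Image
    from PIL.ImageOps import exif_transpose
    from toolkit import image_utils

    def passes_min_resolution(file: str, scale: float, resolution: int) -> bool:
        try:
            # fast path: width/height from the file header, no pixel decode
            w, h = image_utils.get_image_size(file)
        except image_utils.UnknownImageFormat:
            # slow path: open the image fully and honor EXIF rotation
            w, h = exif_transpose(Image.open(file)).size
        return int(min(w, h) * scale) >= resolution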