mirror of https://github.com/ostris/ai-toolkit.git
Allow finetuning tiny autoencoder in vae trainer
@@ -23,7 +23,7 @@ from toolkit.metadata import get_meta_for_safetensors
 from toolkit.optimizer import get_optimizer
 from toolkit.style import get_style_model_and_losses
 from toolkit.train_tools import get_torch_dtype
-from diffusers import AutoencoderKL
+from diffusers import AutoencoderKL, AutoencoderTiny
 from tqdm import tqdm
 import math
 import torchvision.utils
@@ -94,6 +94,9 @@ class TrainVAEProcess(BaseTrainProcess):
         self.dropout = self.get_conf('dropout', 0.0, as_type=float)
         self.train_encoder = self.get_conf('train_encoder', False, as_type=bool)
         self.random_scaling = self.get_conf('random_scaling', False, as_type=bool)
+        self.vae_type = self.get_conf('vae_type', 'AutoencoderKL', as_type=str)  # AutoencoderKL or AutoencoderTiny
+
+        self.VaeClass = AutoencoderKL if self.vae_type == 'AutoencoderKL' else AutoencoderTiny

         if not self.train_encoder:
             # remove losses that only target encoder
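Not part of the commit, but for context: the class swap above works because AutoencoderKL and AutoencoderTiny expose the same loading entry points in diffusers. A minimal sketch of the selection outside the trainer ('madebyollin/taesd' is only an example tiny-autoencoder checkpoint, not something this commit references):

from diffusers import AutoencoderKL, AutoencoderTiny

# Same selection logic as the diff: any value other than 'AutoencoderKL'
# picks the tiny autoencoder.
vae_type = 'AutoencoderTiny'
VaeClass = AutoencoderKL if vae_type == 'AutoencoderKL' else AutoencoderTiny

# Both classes support from_pretrained, which is what the loader relies on.
vae = VaeClass.from_pretrained('madebyollin/taesd')
print(type(vae).__name__)  # AutoencoderTiny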
@@ -407,7 +410,8 @@ class TrainVAEProcess(BaseTrainProcess):
             input_img = img
             img = IMAGE_TRANSFORMS(img).unsqueeze(0).to(self.device, dtype=self.torch_dtype)
             img = img
-            latent = self.vae.encode(img).latent_dist.sample()
+            # latent = self.vae.encode(img).latent_dist.sample()
+            latent = self.vae.encode(img, return_dict=False)[0]

             latent_img = latent.clone()
             bs, ch, h, w = latent_img.shape
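The switch to return_dict=False here reflects an API difference between the two autoencoders: AutoencoderKL.encode returns a latent distribution that still has to be sampled, while AutoencoderTiny.encode returns the latent tensor directly. A rough, weight-free sketch of that difference (the deliberately small AutoencoderKL config below only keeps the example light; it is not something the trainer uses):

import torch
from diffusers import AutoencoderKL, AutoencoderTiny

x = torch.randn(1, 3, 64, 64)

# Randomly initialized models so nothing has to be downloaded.
kl = AutoencoderKL(
    block_out_channels=(64,),
    down_block_types=('DownEncoderBlock2D',),
    up_block_types=('UpDecoderBlock2D',),
    layers_per_block=1,
)
tiny = AutoencoderTiny()  # default TAESD-style layout

with torch.no_grad():
    # AutoencoderKL: encode() returns an output holding a .latent_dist to sample.
    kl_latent = kl.encode(x).latent_dist.sample()
    # AutoencoderTiny: encode() already returns latents; with return_dict=False
    # the first tuple element is the tensor, as used in the diff above.
    tiny_latent = tiny.encode(x, return_dict=False)[0]

print(kl_latent.shape, tiny_latent.shape)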
@@ -492,9 +496,9 @@ class TrainVAEProcess(BaseTrainProcess):
         self.print(f" - Loading VAE: {path_to_load}")
         if self.vae is None:
             if path_to_load is not None:
-                self.vae = AutoencoderKL.from_pretrained(path_to_load)
+                self.vae = self.VaeClass.from_pretrained(path_to_load)
             elif self.vae_config is not None:
-                self.vae = AutoencoderKL(**self.vae_config)
+                self.vae = self.VaeClass(**self.vae_config)
             else:
                 raise ValueError('vae_path or ae_config must be specified')

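One caveat with the elif branch: the two classes take different constructor arguments, so a vae_config written for AutoencoderKL (block_out_channels, down_block_types, ...) will not instantiate an AutoencoderTiny. A hedged sketch of a tiny-autoencoder config, with argument names as currently found in diffusers:

from diffusers import AutoencoderTiny

# AutoencoderTiny keeps separate channel lists for the encoder and decoder,
# so a KL-style config dict cannot simply be reused here.
tiny_config = dict(
    encoder_block_out_channels=(64, 64, 64, 64),
    decoder_block_out_channels=(64, 64, 64, 64),
    latent_channels=4,
)
tiny = AutoencoderTiny(**tiny_config)
print(tiny.config.latent_channels)  # 4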
@@ -511,7 +515,7 @@ class TrainVAEProcess(BaseTrainProcess):
         if self.target_latent_vae_path is not None:
             self.print(f"Loading target latent VAE from {self.target_latent_vae_path}")
             self.target_latent_vae = AutoencoderKL.from_pretrained(self.target_latent_vae_path)
-            self.target_latent_vae.to(self.device, dtype=self.torch_dtype)
+            self.target_latent_vae.to(self.device, dtype=torch.float32)
             self.target_latent_vae.eval()
             self.target_vae_scale_factor = 2 ** (len(self.target_latent_vae.config['block_out_channels']) - 1)
         else:
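The scale factor above follows the usual diffusers convention for KL VAEs: each resolution block beyond the first halves the spatial size. For the standard SD VAE config that works out to the familiar factor of 8:

# The standard SD KL VAE config has four resolution blocks.
block_out_channels = [128, 256, 512, 512]
target_vae_scale_factor = 2 ** (len(block_out_channels) - 1)
print(target_vae_scale_factor)  # 8: a 512x512 image becomes a 64x64 latent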
@@ -664,20 +668,26 @@ class TrainVAEProcess(BaseTrainProcess):
                 target_input_scale = self.target_vae_scale_factor / self.vae_scale_factor
                 target_input_size = (int(batch.shape[2] * target_input_scale), int(batch.shape[3] * target_input_scale))
                 # resize to target input size
-                target_input_batch = Resize(target_input_size)(batch)
+                target_input_batch = Resize(target_input_size)(batch).to(self.device, dtype=torch.float32)
                 target_latent = self.target_latent_vae.encode(target_input_batch).latent_dist.sample().detach()
+                target_latent = target_latent.to(self.device, dtype=self.torch_dtype)


         # forward pass
         # grad only if eq_vae
         with torch.set_grad_enabled(self.train_encoder):
-            dgd = self.vae.encode(batch).latent_dist
-            mu, logvar = dgd.mean, dgd.logvar
-            latents = dgd.sample()
+            if self.vae_type == 'AutoencoderTiny':
+                # AutoencoderTiny cannot do latent distribution sampling
+                latents = self.vae.encode(batch, return_dict=False)[0]
+                mu, logvar = None, None
+            else:
+                dgd = self.vae.encode(batch).latent_dist
+                mu, logvar = dgd.mean, dgd.logvar
+                latents = dgd.sample()

         if target_latent is not None:
             # forward_latents = target_latent.detach()
-            lat_mse_loss = self.get_mse_loss(target_latent, latents)
-            latents = target_latent.detach()
+            lat_mse_loss = self.get_mse_loss(target_latent.float(), latents.float())
+            forward_latents = target_latent.detach()

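Because AutoencoderTiny is deterministic, mu and logvar come back as None in the new branch, so any KL-divergence regularizer further down has to be guarded for that path. A hedged sketch of such a guard (the helper is illustrative, not code from this commit):

import torch

def kl_divergence(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    # Closed-form KL(N(mu, sigma^2) || N(0, 1)), the usual KL-VAE regularizer.
    return -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())

mu, logvar = None, None  # what the AutoencoderTiny branch produces
kld_loss = (
    kl_divergence(mu, logvar)
    if mu is not None and logvar is not None
    else torch.tensor(0.0)
)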
@@ -16,6 +16,7 @@ from torch.utils.data import Dataset, DataLoader, ConcatDataset
 from tqdm import tqdm
 import albumentations as A

+from toolkit import image_utils
 from toolkit.buckets import get_bucket_for_image_size, BucketResolution
 from toolkit.config_modules import DatasetConfig, preprocess_dataset_raw_config
 from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin, LatentCachingMixin, Augments, CLIPCachingMixin, ControlCachingMixin
@@ -100,8 +101,13 @@ class ImageDataset(Dataset, CaptionMixin):
         new_file_list = []
         bad_count = 0
         for file in tqdm(self.file_list):
-            img = Image.open(file)
-            if int(min(img.size) * self.scale) >= self.resolution:
+            try:
+                w, h = image_utils.get_image_size(file)
+            except image_utils.UnknownImageFormat:
+                img = exif_transpose(Image.open(file))
+                w, h = img.size
+            # img = Image.open(file)
+            if int(min([w, h]) * self.scale) >= self.resolution:
                 new_file_list.append(file)
             else:
                 bad_count += 1
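image_utils.get_image_size and UnknownImageFormat are toolkit helpers; the point of the change is to read width and height from file headers instead of opening every image with PIL, falling back to PIL (with EXIF orientation applied) only when the fast path does not recognize the format. A PIL-only sketch of the same size filter, for reference rather than as the toolkit implementation:

from PIL import Image
from PIL.ImageOps import exif_transpose

def passes_resolution(path: str, scale: float, resolution: int) -> bool:
    # Image.open is lazy about pixel data, but exif_transpose may still decode
    # the image to apply a rotation tag, which is presumably why the commit
    # prefers a header-only size reader when the format is known.
    img = exif_transpose(Image.open(path))
    w, h = img.size
    return int(min(w, h) * scale) >= resolution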