Mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-02-02 20:07:53 +00:00)
Bug fixes
@@ -557,6 +557,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
    def hook_before_train_loop(self):
        pass

    def ensure_params_requires_grad(self):
        # get param groups
        for group in self.optimizer.param_groups:
            for param in group['params']:
                param.requires_grad = True

    def setup_ema(self):
        if self.train_config.ema_config.use_ema:
            # our params are in groups. We need them as a single iterable
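The new ensure_params_requires_grad helper walks every optimizer param group and switches requires_grad back on. The hunks below wire it in right before the train loop and again after sampling and after saving, presumably because those paths can leave trainable parameters with gradients disabled.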
@@ -1535,6 +1541,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
        # print(f"Compiling Model")
        # torch.compile(self.sd.unet, dynamic=True)

        # make sure all params require grad
        self.ensure_params_requires_grad()


        ###################################################################
        # TRAIN LOOP
        ###################################################################
@@ -1652,6 +1662,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
                if self.train_config.free_u:
                    self.sd.pipeline.disable_freeu()
                self.sample(self.step_num)
                self.ensure_params_requires_grad()
                self.progress_bar.unpause()

            if is_save_step:
@@ -1659,6 +1670,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
                self.progress_bar.pause()
                self.print(f"Saving at step {self.step_num}")
                self.save(self.step_num)
                self.ensure_params_requires_grad()
                self.progress_bar.unpause()

            if self.logging_config.log_every and self.step_num % self.logging_config.log_every == 0:
@@ -47,7 +47,7 @@ class GenerateConfig:

        self.random_prompts = kwargs.get('random_prompts', False)
        self.max_random_per_prompt = kwargs.get('max_random_per_prompt', 1)
-       self.max_images = kwargs.get('max_prompts', 10000)
+       self.max_images = kwargs.get('max_images', 10000)

        if self.random_prompts:
            self.prompts = []
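The one-line change reads the image cap from the matching key. A minimal sketch of the corrected behavior, assuming GenerateConfig is built from plain keyword arguments as the kwargs.get calls suggest:

# hypothetical check of the fixed kwarg lookup
cfg = GenerateConfig(random_prompts=True, max_images=500)
assert cfg.max_images == 500  # before the fix this fell back to the 10000 default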
@@ -25,3 +25,4 @@ python-dotenv
bitsandbytes
hf_transfer
lpips
pytorch_fid
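lpips and pytorch_fid back the metrics used by the new testing/test_vae.py script below: lpips supplies the learned perceptual similarity model, while pytorch_fid is imported for FID even though the rFID computation is still commented out.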
testing/test_vae.py (new file, 113 lines)
@@ -0,0 +1,113 @@
import argparse
import os
from PIL import Image
import torch
from torchvision.transforms import Resize, ToTensor
from diffusers import AutoencoderKL
from pytorch_fid import fid_score
from skimage.metrics import peak_signal_noise_ratio as psnr
import lpips
from tqdm import tqdm
from torchvision import transforms

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def load_images(folder_path):
    images = []
    for filename in os.listdir(folder_path):
        if filename.lower().endswith(('.png', '.jpg', '.jpeg')):
            img_path = os.path.join(folder_path, filename)
            images.append(img_path)
    return images


def paramiter_count(model):
    state_dict = model.state_dict()
    paramiter_count = 0
    for key in state_dict:
        paramiter_count += torch.numel(state_dict[key])
    return int(paramiter_count)


def calculate_metrics(vae, images, max_imgs=-1):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    vae = vae.to(device)
    lpips_model = lpips.LPIPS(net='alex').to(device)

    rfid_scores = []
    psnr_scores = []
    lpips_scores = []

    # transform = transforms.Compose([
    #     transforms.Resize(256, antialias=True),
    #     transforms.CenterCrop(256)
    # ])
    # needs values between -1 and 1
    to_tensor = ToTensor()

    if max_imgs > 0 and len(images) > max_imgs:
        images = images[:max_imgs]

    for img_path in tqdm(images):
        try:
            img = Image.open(img_path).convert('RGB')
            # img_tensor = to_tensor(transform(img)).unsqueeze(0).to(device)
            img_tensor = to_tensor(img).unsqueeze(0).to(device)
            img_tensor = 2 * img_tensor - 1
            # if width or height is not divisible by 8, crop it
            if img_tensor.shape[2] % 8 != 0 or img_tensor.shape[3] % 8 != 0:
                img_tensor = img_tensor[:, :, :img_tensor.shape[2] // 8 * 8, :img_tensor.shape[3] // 8 * 8]

        except Exception as e:
            print(f"Error processing {img_path}: {e}")
            continue

        with torch.no_grad():
            reconstructed = vae.decode(vae.encode(img_tensor).latent_dist.sample()).sample

        # Calculate rFID
        # rfid = fid_score.calculate_frechet_distance(vae, img_tensor, reconstructed)
        # rfid_scores.append(rfid)

        # Calculate PSNR
        psnr_val = psnr(img_tensor.cpu().numpy(), reconstructed.cpu().numpy())
        psnr_scores.append(psnr_val)

        # Calculate LPIPS
        lpips_val = lpips_model(img_tensor, reconstructed).item()
        lpips_scores.append(lpips_val)

    # avg_rfid = sum(rfid_scores) / len(rfid_scores)
    avg_rfid = 0
    avg_psnr = sum(psnr_scores) / len(psnr_scores)
    avg_lpips = sum(lpips_scores) / len(lpips_scores)

    return avg_rfid, avg_psnr, avg_lpips


def main():
    parser = argparse.ArgumentParser(description="Calculate average rFID, PSNR, and LPIPS for VAE reconstructions")
    parser.add_argument("--vae_path", type=str, required=True, help="Path to the VAE model")
    parser.add_argument("--image_folder", type=str, required=True, help="Path to the folder containing images")
    parser.add_argument("--max_imgs", type=int, default=-1, help="Max num of images. Default is -1 for all images.")
    args = parser.parse_args()

    if os.path.isfile(args.vae_path):
        vae = AutoencoderKL.from_single_file(args.vae_path)
    else:
        vae = AutoencoderKL.from_pretrained(args.vae_path)
    vae.eval()
    vae = vae.to(device)
    print(f"Model has {paramiter_count(vae)} parameters")
    images = load_images(args.image_folder)

    avg_rfid, avg_psnr, avg_lpips = calculate_metrics(vae, images, args.max_imgs)

    # print(f"Average rFID: {avg_rfid}")
    print(f"Average PSNR: {avg_psnr}")
    print(f"Average LPIPS: {avg_lpips}")


if __name__ == "__main__":
    main()
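Usage follows from the argparse flags above; a hypothetical invocation (the model id and folder are placeholders):

python testing/test_vae.py --vae_path stabilityai/sd-vae-ft-mse --image_folder ./eval_images --max_imgs 100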
@@ -403,11 +403,8 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
            self.transformer_pos_embed = copy.deepcopy(transformer.pos_embed)
            self.transformer_proj_out = copy.deepcopy(transformer.proj_out)

            transformer.pos_embed.orig_forward = transformer.pos_embed.forward
            transformer.proj_out.orig_forward = transformer.proj_out.forward

            transformer.pos_embed.forward = self.transformer_pos_embed.forward
            transformer.proj_out.forward = self.transformer_proj_out.forward
            transformer.pos_embed = self.transformer_pos_embed
            transformer.proj_out = self.transformer_proj_out

        else:
            unet: UNet2DConditionModel = unet
@@ -417,10 +414,8 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
            # clone these and replace their forwards with ours
            self.unet_conv_in = copy.deepcopy(unet_conv_in)
            self.unet_conv_out = copy.deepcopy(unet_conv_out)
            unet.conv_in.orig_forward = unet_conv_in.forward
            unet_conv_out.orig_forward = unet_conv_out.forward
            unet.conv_in.forward = self.unet_conv_in.forward
            unet.conv_out.forward = self.unet_conv_out.forward
            unet.conv_in = self.unet_conv_in
            unet.conv_out = self.unet_conv_out

    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
        # call Lora prepare_optimizer_params
@@ -123,6 +123,8 @@ def get_sampler(
            "num_train_timesteps": 1000,
            "shift": 3.0
        }
    else:
        raise ValueError(f"Sampler {sampler} not supported")


    config = copy.deepcopy(config_to_use)
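The two added entries (num_train_timesteps: 1000 and shift: 3.0) look like defaults for a flow-matching style scheduler config; the deepcopy of config_to_use keeps the shared template dict from being mutated on each call.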