From acb06d6ff361e2b2f4cda4b426976e6247eeeda1 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Wed, 3 Jul 2024 10:56:34 -0600 Subject: [PATCH] Bug fixes --- jobs/process/BaseSDTrainProcess.py | 12 +++ jobs/process/GenerateProcess.py | 2 +- requirements.txt | 1 + testing/test_vae.py | 113 +++++++++++++++++++++++++++++ toolkit/lora_special.py | 13 +--- toolkit/sampler.py | 2 + 6 files changed, 133 insertions(+), 10 deletions(-) create mode 100644 testing/test_vae.py diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 20d7d322..b863dec9 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -557,6 +557,12 @@ class BaseSDTrainProcess(BaseTrainProcess): def hook_before_train_loop(self): pass + def ensure_params_requires_grad(self): + # get param groups + for group in self.optimizer.param_groups: + for param in group['params']: + param.requires_grad = True + def setup_ema(self): if self.train_config.ema_config.use_ema: # our params are in groups. We need them as a single iterable @@ -1535,6 +1541,10 @@ class BaseSDTrainProcess(BaseTrainProcess): # print(f"Compiling Model") # torch.compile(self.sd.unet, dynamic=True) + # make sure all params require grad + self.ensure_params_requires_grad() + + ################################################################### # TRAIN LOOP ################################################################### @@ -1652,6 +1662,7 @@ class BaseSDTrainProcess(BaseTrainProcess): if self.train_config.free_u: self.sd.pipeline.disable_freeu() self.sample(self.step_num) + self.ensure_params_requires_grad() self.progress_bar.unpause() if is_save_step: @@ -1659,6 +1670,7 @@ class BaseSDTrainProcess(BaseTrainProcess): self.progress_bar.pause() self.print(f"Saving at step {self.step_num}") self.save(self.step_num) + self.ensure_params_requires_grad() self.progress_bar.unpause() if self.logging_config.log_every and self.step_num % self.logging_config.log_every == 0: diff --git a/jobs/process/GenerateProcess.py b/jobs/process/GenerateProcess.py index 95bb7427..4a9ff8c5 100644 --- a/jobs/process/GenerateProcess.py +++ b/jobs/process/GenerateProcess.py @@ -47,7 +47,7 @@ class GenerateConfig: self.random_prompts = kwargs.get('random_prompts', False) self.max_random_per_prompt = kwargs.get('max_random_per_prompt', 1) - self.max_images = kwargs.get('max_prompts', 10000) + self.max_images = kwargs.get('max_images', 10000) if self.random_prompts: self.prompts = [] diff --git a/requirements.txt b/requirements.txt index 347cc0e1..ba577a8b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -25,3 +25,4 @@ python-dotenv bitsandbytes hf_transfer lpips +pytorch_fid \ No newline at end of file diff --git a/testing/test_vae.py b/testing/test_vae.py new file mode 100644 index 00000000..44b31f63 --- /dev/null +++ b/testing/test_vae.py @@ -0,0 +1,113 @@ +import argparse +import os +from PIL import Image +import torch +from torchvision.transforms import Resize, ToTensor +from diffusers import AutoencoderKL +from pytorch_fid import fid_score +from skimage.metrics import peak_signal_noise_ratio as psnr +import lpips +from tqdm import tqdm +from torchvision import transforms + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +def load_images(folder_path): + images = [] + for filename in os.listdir(folder_path): + if filename.lower().endswith(('.png', '.jpg', '.jpeg')): + img_path = os.path.join(folder_path, filename) + images.append(img_path) + return images + + +def paramiter_count(model): + state_dict = model.state_dict() + paramiter_count = 0 + for key in state_dict: + paramiter_count += torch.numel(state_dict[key]) + return int(paramiter_count) + + +def calculate_metrics(vae, images, max_imgs=-1): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + vae = vae.to(device) + lpips_model = lpips.LPIPS(net='alex').to(device) + + rfid_scores = [] + psnr_scores = [] + lpips_scores = [] + + # transform = transforms.Compose([ + # transforms.Resize(256, antialias=True), + # transforms.CenterCrop(256) + # ]) + # needs values between -1 and 1 + to_tensor = ToTensor() + + if max_imgs > 0 and len(images) > max_imgs: + images = images[:max_imgs] + + for img_path in tqdm(images): + try: + img = Image.open(img_path).convert('RGB') + # img_tensor = to_tensor(transform(img)).unsqueeze(0).to(device) + img_tensor = to_tensor(img).unsqueeze(0).to(device) + img_tensor = 2 * img_tensor - 1 + # if width or height is not divisible by 8, crop it + if img_tensor.shape[2] % 8 != 0 or img_tensor.shape[3] % 8 != 0: + img_tensor = img_tensor[:, :, :img_tensor.shape[2] // 8 * 8, :img_tensor.shape[3] // 8 * 8] + + except Exception as e: + print(f"Error processing {img_path}: {e}") + continue + + + with torch.no_grad(): + reconstructed = vae.decode(vae.encode(img_tensor).latent_dist.sample()).sample + + # Calculate rFID + # rfid = fid_score.calculate_frechet_distance(vae, img_tensor, reconstructed) + # rfid_scores.append(rfid) + + # Calculate PSNR + psnr_val = psnr(img_tensor.cpu().numpy(), reconstructed.cpu().numpy()) + psnr_scores.append(psnr_val) + + # Calculate LPIPS + lpips_val = lpips_model(img_tensor, reconstructed).item() + lpips_scores.append(lpips_val) + + # avg_rfid = sum(rfid_scores) / len(rfid_scores) + avg_rfid = 0 + avg_psnr = sum(psnr_scores) / len(psnr_scores) + avg_lpips = sum(lpips_scores) / len(lpips_scores) + + return avg_rfid, avg_psnr, avg_lpips + + +def main(): + parser = argparse.ArgumentParser(description="Calculate average rFID, PSNR, and LPIPS for VAE reconstructions") + parser.add_argument("--vae_path", type=str, required=True, help="Path to the VAE model") + parser.add_argument("--image_folder", type=str, required=True, help="Path to the folder containing images") + parser.add_argument("--max_imgs", type=int, default=-1, help="Max num of images. Default is -1 for all images.") + args = parser.parse_args() + + if os.path.isfile(args.vae_path): + vae = AutoencoderKL.from_single_file(args.vae_path) + else: + vae = AutoencoderKL.from_pretrained(args.vae_path) + vae.eval() + vae = vae.to(device) + print(f"Model has {paramiter_count(vae)} parameters") + images = load_images(args.image_folder) + + avg_rfid, avg_psnr, avg_lpips = calculate_metrics(vae, images, args.max_imgs) + + # print(f"Average rFID: {avg_rfid}") + print(f"Average PSNR: {avg_psnr}") + print(f"Average LPIPS: {avg_lpips}") + + +if __name__ == "__main__": + main() diff --git a/toolkit/lora_special.py b/toolkit/lora_special.py index 9449db47..bcb2d63d 100644 --- a/toolkit/lora_special.py +++ b/toolkit/lora_special.py @@ -403,11 +403,8 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork): self.transformer_pos_embed = copy.deepcopy(transformer.pos_embed) self.transformer_proj_out = copy.deepcopy(transformer.proj_out) - transformer.pos_embed.orig_forward = transformer.pos_embed.forward - transformer.proj_out.orig_forward = transformer.proj_out.forward - - transformer.pos_embed.forward = self.transformer_pos_embed.forward - transformer.proj_out.forward = self.transformer_proj_out.forward + transformer.pos_embed = self.transformer_pos_embed + transformer.proj_out = self.transformer_proj_out else: unet: UNet2DConditionModel = unet @@ -417,10 +414,8 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork): # clone these and replace their forwards with ours self.unet_conv_in = copy.deepcopy(unet_conv_in) self.unet_conv_out = copy.deepcopy(unet_conv_out) - unet.conv_in.orig_forward = unet_conv_in.forward - unet_conv_out.orig_forward = unet_conv_out.forward - unet.conv_in.forward = self.unet_conv_in.forward - unet.conv_out.forward = self.unet_conv_out.forward + unet.conv_in = self.unet_conv_in + unet.conv_out = self.unet_conv_out def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr): # call Lora prepare_optimizer_params diff --git a/toolkit/sampler.py b/toolkit/sampler.py index e6c6e32e..6d42d94f 100644 --- a/toolkit/sampler.py +++ b/toolkit/sampler.py @@ -123,6 +123,8 @@ def get_sampler( "num_train_timesteps": 1000, "shift": 3.0 } + else: + raise ValueError(f"Sampler {sampler} not supported") config = copy.deepcopy(config_to_use)