Bug fixes

This commit is contained in:
Jaret Burkett
2024-07-03 10:56:34 -06:00
parent bb57623a35
commit acb06d6ff3
6 changed files with 133 additions and 10 deletions

View File

@@ -557,6 +557,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
def hook_before_train_loop(self):
pass
def ensure_params_requires_grad(self):
# get param groups
for group in self.optimizer.param_groups:
for param in group['params']:
param.requires_grad = True
def setup_ema(self):
if self.train_config.ema_config.use_ema:
# our params are in groups. We need them as a single iterable
@@ -1535,6 +1541,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
# print(f"Compiling Model")
# torch.compile(self.sd.unet, dynamic=True)
# make sure all params require grad
self.ensure_params_requires_grad()
###################################################################
# TRAIN LOOP
###################################################################
@@ -1652,6 +1662,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
if self.train_config.free_u:
self.sd.pipeline.disable_freeu()
self.sample(self.step_num)
self.ensure_params_requires_grad()
self.progress_bar.unpause()
if is_save_step:
@@ -1659,6 +1670,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.progress_bar.pause()
self.print(f"Saving at step {self.step_num}")
self.save(self.step_num)
self.ensure_params_requires_grad()
self.progress_bar.unpause()
if self.logging_config.log_every and self.step_num % self.logging_config.log_every == 0:

View File

@@ -47,7 +47,7 @@ class GenerateConfig:
self.random_prompts = kwargs.get('random_prompts', False)
self.max_random_per_prompt = kwargs.get('max_random_per_prompt', 1)
self.max_images = kwargs.get('max_prompts', 10000)
self.max_images = kwargs.get('max_images', 10000)
if self.random_prompts:
self.prompts = []

View File

@@ -25,3 +25,4 @@ python-dotenv
bitsandbytes
hf_transfer
lpips
pytorch_fid

113
testing/test_vae.py Normal file
View File

@@ -0,0 +1,113 @@
import argparse
import os
from PIL import Image
import torch
from torchvision.transforms import Resize, ToTensor
from diffusers import AutoencoderKL
from pytorch_fid import fid_score
from skimage.metrics import peak_signal_noise_ratio as psnr
import lpips
from tqdm import tqdm
from torchvision import transforms
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def load_images(folder_path):
images = []
for filename in os.listdir(folder_path):
if filename.lower().endswith(('.png', '.jpg', '.jpeg')):
img_path = os.path.join(folder_path, filename)
images.append(img_path)
return images
def paramiter_count(model):
state_dict = model.state_dict()
paramiter_count = 0
for key in state_dict:
paramiter_count += torch.numel(state_dict[key])
return int(paramiter_count)
def calculate_metrics(vae, images, max_imgs=-1):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vae = vae.to(device)
lpips_model = lpips.LPIPS(net='alex').to(device)
rfid_scores = []
psnr_scores = []
lpips_scores = []
# transform = transforms.Compose([
# transforms.Resize(256, antialias=True),
# transforms.CenterCrop(256)
# ])
# needs values between -1 and 1
to_tensor = ToTensor()
if max_imgs > 0 and len(images) > max_imgs:
images = images[:max_imgs]
for img_path in tqdm(images):
try:
img = Image.open(img_path).convert('RGB')
# img_tensor = to_tensor(transform(img)).unsqueeze(0).to(device)
img_tensor = to_tensor(img).unsqueeze(0).to(device)
img_tensor = 2 * img_tensor - 1
# if width or height is not divisible by 8, crop it
if img_tensor.shape[2] % 8 != 0 or img_tensor.shape[3] % 8 != 0:
img_tensor = img_tensor[:, :, :img_tensor.shape[2] // 8 * 8, :img_tensor.shape[3] // 8 * 8]
except Exception as e:
print(f"Error processing {img_path}: {e}")
continue
with torch.no_grad():
reconstructed = vae.decode(vae.encode(img_tensor).latent_dist.sample()).sample
# Calculate rFID
# rfid = fid_score.calculate_frechet_distance(vae, img_tensor, reconstructed)
# rfid_scores.append(rfid)
# Calculate PSNR
psnr_val = psnr(img_tensor.cpu().numpy(), reconstructed.cpu().numpy())
psnr_scores.append(psnr_val)
# Calculate LPIPS
lpips_val = lpips_model(img_tensor, reconstructed).item()
lpips_scores.append(lpips_val)
# avg_rfid = sum(rfid_scores) / len(rfid_scores)
avg_rfid = 0
avg_psnr = sum(psnr_scores) / len(psnr_scores)
avg_lpips = sum(lpips_scores) / len(lpips_scores)
return avg_rfid, avg_psnr, avg_lpips
def main():
parser = argparse.ArgumentParser(description="Calculate average rFID, PSNR, and LPIPS for VAE reconstructions")
parser.add_argument("--vae_path", type=str, required=True, help="Path to the VAE model")
parser.add_argument("--image_folder", type=str, required=True, help="Path to the folder containing images")
parser.add_argument("--max_imgs", type=int, default=-1, help="Max num of images. Default is -1 for all images.")
args = parser.parse_args()
if os.path.isfile(args.vae_path):
vae = AutoencoderKL.from_single_file(args.vae_path)
else:
vae = AutoencoderKL.from_pretrained(args.vae_path)
vae.eval()
vae = vae.to(device)
print(f"Model has {paramiter_count(vae)} parameters")
images = load_images(args.image_folder)
avg_rfid, avg_psnr, avg_lpips = calculate_metrics(vae, images, args.max_imgs)
# print(f"Average rFID: {avg_rfid}")
print(f"Average PSNR: {avg_psnr}")
print(f"Average LPIPS: {avg_lpips}")
if __name__ == "__main__":
main()

View File

@@ -403,11 +403,8 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
self.transformer_pos_embed = copy.deepcopy(transformer.pos_embed)
self.transformer_proj_out = copy.deepcopy(transformer.proj_out)
transformer.pos_embed.orig_forward = transformer.pos_embed.forward
transformer.proj_out.orig_forward = transformer.proj_out.forward
transformer.pos_embed.forward = self.transformer_pos_embed.forward
transformer.proj_out.forward = self.transformer_proj_out.forward
transformer.pos_embed = self.transformer_pos_embed
transformer.proj_out = self.transformer_proj_out
else:
unet: UNet2DConditionModel = unet
@@ -417,10 +414,8 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
# clone these and replace their forwards with ours
self.unet_conv_in = copy.deepcopy(unet_conv_in)
self.unet_conv_out = copy.deepcopy(unet_conv_out)
unet.conv_in.orig_forward = unet_conv_in.forward
unet_conv_out.orig_forward = unet_conv_out.forward
unet.conv_in.forward = self.unet_conv_in.forward
unet.conv_out.forward = self.unet_conv_out.forward
unet.conv_in = self.unet_conv_in
unet.conv_out = self.unet_conv_out
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
# call Lora prepare_optimizer_params

View File

@@ -123,6 +123,8 @@ def get_sampler(
"num_train_timesteps": 1000,
"shift": 3.0
}
else:
raise ValueError(f"Sampler {sampler} not supported")
config = copy.deepcopy(config_to_use)