mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 10:41:28 +00:00
Bug fixes
This commit is contained in:
@@ -557,6 +557,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
def hook_before_train_loop(self):
|
def hook_before_train_loop(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def ensure_params_requires_grad(self):
    """Re-enable gradient tracking on every parameter the optimizer manages.

    Sampling/saving code paths may toggle requires_grad off; this walks all
    optimizer param groups and flips it back on so training can resume.
    """
    all_params = (p for group in self.optimizer.param_groups for p in group['params'])
    for p in all_params:
        p.requires_grad = True
|
||||||
|
|
||||||
def setup_ema(self):
|
def setup_ema(self):
|
||||||
if self.train_config.ema_config.use_ema:
|
if self.train_config.ema_config.use_ema:
|
||||||
# our params are in groups. We need them as a single iterable
|
# our params are in groups. We need them as a single iterable
|
||||||
@@ -1535,6 +1541,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
# print(f"Compiling Model")
|
# print(f"Compiling Model")
|
||||||
# torch.compile(self.sd.unet, dynamic=True)
|
# torch.compile(self.sd.unet, dynamic=True)
|
||||||
|
|
||||||
|
# make sure all params require grad
|
||||||
|
self.ensure_params_requires_grad()
|
||||||
|
|
||||||
|
|
||||||
###################################################################
|
###################################################################
|
||||||
# TRAIN LOOP
|
# TRAIN LOOP
|
||||||
###################################################################
|
###################################################################
|
||||||
@@ -1652,6 +1662,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
if self.train_config.free_u:
|
if self.train_config.free_u:
|
||||||
self.sd.pipeline.disable_freeu()
|
self.sd.pipeline.disable_freeu()
|
||||||
self.sample(self.step_num)
|
self.sample(self.step_num)
|
||||||
|
self.ensure_params_requires_grad()
|
||||||
self.progress_bar.unpause()
|
self.progress_bar.unpause()
|
||||||
|
|
||||||
if is_save_step:
|
if is_save_step:
|
||||||
@@ -1659,6 +1670,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
self.progress_bar.pause()
|
self.progress_bar.pause()
|
||||||
self.print(f"Saving at step {self.step_num}")
|
self.print(f"Saving at step {self.step_num}")
|
||||||
self.save(self.step_num)
|
self.save(self.step_num)
|
||||||
|
self.ensure_params_requires_grad()
|
||||||
self.progress_bar.unpause()
|
self.progress_bar.unpause()
|
||||||
|
|
||||||
if self.logging_config.log_every and self.step_num % self.logging_config.log_every == 0:
|
if self.logging_config.log_every and self.step_num % self.logging_config.log_every == 0:
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ class GenerateConfig:
|
|||||||
|
|
||||||
self.random_prompts = kwargs.get('random_prompts', False)
|
self.random_prompts = kwargs.get('random_prompts', False)
|
||||||
self.max_random_per_prompt = kwargs.get('max_random_per_prompt', 1)
|
self.max_random_per_prompt = kwargs.get('max_random_per_prompt', 1)
|
||||||
self.max_images = kwargs.get('max_prompts', 10000)
|
self.max_images = kwargs.get('max_images', 10000)
|
||||||
|
|
||||||
if self.random_prompts:
|
if self.random_prompts:
|
||||||
self.prompts = []
|
self.prompts = []
|
||||||
|
|||||||
@@ -25,3 +25,4 @@ python-dotenv
|
|||||||
bitsandbytes
|
bitsandbytes
|
||||||
hf_transfer
|
hf_transfer
|
||||||
lpips
|
lpips
|
||||||
|
pytorch_fid
|
||||||
113
testing/test_vae.py
Normal file
113
testing/test_vae.py
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
from PIL import Image
|
||||||
|
import torch
|
||||||
|
from torchvision.transforms import Resize, ToTensor
|
||||||
|
from diffusers import AutoencoderKL
|
||||||
|
from pytorch_fid import fid_score
|
||||||
|
from skimage.metrics import peak_signal_noise_ratio as psnr
|
||||||
|
import lpips
|
||||||
|
from tqdm import tqdm
|
||||||
|
from torchvision import transforms
|
||||||
|
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
|
def load_images(folder_path):
    """Return paths of image files (.png/.jpg/.jpeg) directly inside *folder_path*.

    Matching is case-insensitive on the extension; subdirectories are not
    searched. Order follows os.listdir.
    """
    image_exts = ('.png', '.jpg', '.jpeg')
    return [
        os.path.join(folder_path, name)
        for name in os.listdir(folder_path)
        if name.lower().endswith(image_exts)
    ]
|
||||||
|
|
||||||
|
|
||||||
|
def paramiter_count(model):
    """Count the total number of scalar elements in the model's state dict.

    Note: iterates state_dict(), so persistent buffers are included along
    with trainable parameters.
    """
    total = sum(torch.numel(tensor) for tensor in model.state_dict().values())
    return int(total)
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_metrics(vae, images, max_imgs=-1):
    """Compute average reconstruction metrics for a VAE over a set of images.

    Args:
        vae: diffusers AutoencoderKL-style model exposing encode()/decode().
        images: list of image file paths.
        max_imgs: cap on the number of images processed; -1 means no cap.

    Returns:
        Tuple (avg_rfid, avg_psnr, avg_lpips). rFID computation is currently
        disabled, so avg_rfid is always 0. If no image is processed
        successfully, avg_psnr and avg_lpips are 0.0 (previously this raised
        ZeroDivisionError).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    vae = vae.to(device)
    lpips_model = lpips.LPIPS(net='alex').to(device)

    rfid_scores = []
    psnr_scores = []
    lpips_scores = []

    # VAE expects pixel values in [-1, 1]
    to_tensor = ToTensor()

    if max_imgs > 0 and len(images) > max_imgs:
        images = images[:max_imgs]

    for img_path in tqdm(images):
        try:
            img = Image.open(img_path).convert('RGB')
            img_tensor = to_tensor(img).unsqueeze(0).to(device)
            # rescale [0, 1] -> [-1, 1]
            img_tensor = 2 * img_tensor - 1
            # crop so H and W are divisible by 8 (the VAE's downsampling factor)
            if img_tensor.shape[2] % 8 != 0 or img_tensor.shape[3] % 8 != 0:
                img_tensor = img_tensor[:, :, :img_tensor.shape[2] // 8 * 8, :img_tensor.shape[3] // 8 * 8]
        except Exception as e:
            # best-effort: skip unreadable/corrupt images and keep going
            print(f"Error processing {img_path}: {e}")
            continue

        with torch.no_grad():
            reconstructed = vae.decode(vae.encode(img_tensor).latent_dist.sample()).sample

        # rFID is disabled for now
        # rfid = fid_score.calculate_frechet_distance(vae, img_tensor, reconstructed)
        # rfid_scores.append(rfid)

        # PSNR on CPU numpy arrays (skimage infers data_range from the signed floats)
        psnr_val = psnr(img_tensor.cpu().numpy(), reconstructed.cpu().numpy())
        psnr_scores.append(psnr_val)

        # LPIPS perceptual distance
        lpips_val = lpips_model(img_tensor, reconstructed).item()
        lpips_scores.append(lpips_val)

    # rFID disabled; report 0 to keep the return shape stable
    avg_rfid = 0
    # Guard against ZeroDivisionError when the folder was empty or every
    # image failed to load.
    avg_psnr = sum(psnr_scores) / len(psnr_scores) if psnr_scores else 0.0
    avg_lpips = sum(lpips_scores) / len(lpips_scores) if lpips_scores else 0.0

    return avg_rfid, avg_psnr, avg_lpips
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """CLI entry point: load a VAE and report reconstruction metrics for a folder of images."""
    parser = argparse.ArgumentParser(description="Calculate average rFID, PSNR, and LPIPS for VAE reconstructions")
    parser.add_argument("--vae_path", type=str, required=True, help="Path to the VAE model")
    parser.add_argument("--image_folder", type=str, required=True, help="Path to the folder containing images")
    parser.add_argument("--max_imgs", type=int, default=-1, help="Max num of images. Default is -1 for all images.")
    args = parser.parse_args()

    # A single checkpoint file loads via from_single_file; anything else is
    # treated as a pretrained folder / repo id.
    if os.path.isfile(args.vae_path):
        vae = AutoencoderKL.from_single_file(args.vae_path)
    else:
        vae = AutoencoderKL.from_pretrained(args.vae_path)
    vae.eval()
    vae = vae.to(device)

    print(f"Model has {paramiter_count(vae)} parameters")
    images = load_images(args.image_folder)

    avg_rfid, avg_psnr, avg_lpips = calculate_metrics(vae, images, args.max_imgs)

    # rFID is currently disabled in calculate_metrics
    # print(f"Average rFID: {avg_rfid}")
    print(f"Average PSNR: {avg_psnr}")
    print(f"Average LPIPS: {avg_lpips}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -403,11 +403,8 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
|
|||||||
self.transformer_pos_embed = copy.deepcopy(transformer.pos_embed)
|
self.transformer_pos_embed = copy.deepcopy(transformer.pos_embed)
|
||||||
self.transformer_proj_out = copy.deepcopy(transformer.proj_out)
|
self.transformer_proj_out = copy.deepcopy(transformer.proj_out)
|
||||||
|
|
||||||
transformer.pos_embed.orig_forward = transformer.pos_embed.forward
|
transformer.pos_embed = self.transformer_pos_embed
|
||||||
transformer.proj_out.orig_forward = transformer.proj_out.forward
|
transformer.proj_out = self.transformer_proj_out
|
||||||
|
|
||||||
transformer.pos_embed.forward = self.transformer_pos_embed.forward
|
|
||||||
transformer.proj_out.forward = self.transformer_proj_out.forward
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
unet: UNet2DConditionModel = unet
|
unet: UNet2DConditionModel = unet
|
||||||
@@ -417,10 +414,8 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
|
|||||||
# clone these and replace their forwards with ours
|
# clone these and replace their forwards with ours
|
||||||
self.unet_conv_in = copy.deepcopy(unet_conv_in)
|
self.unet_conv_in = copy.deepcopy(unet_conv_in)
|
||||||
self.unet_conv_out = copy.deepcopy(unet_conv_out)
|
self.unet_conv_out = copy.deepcopy(unet_conv_out)
|
||||||
unet.conv_in.orig_forward = unet_conv_in.forward
|
unet.conv_in = self.unet_conv_in
|
||||||
unet_conv_out.orig_forward = unet_conv_out.forward
|
unet.conv_out = self.unet_conv_out
|
||||||
unet.conv_in.forward = self.unet_conv_in.forward
|
|
||||||
unet.conv_out.forward = self.unet_conv_out.forward
|
|
||||||
|
|
||||||
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
|
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
|
||||||
# call Lora prepare_optimizer_params
|
# call Lora prepare_optimizer_params
|
||||||
|
|||||||
@@ -123,6 +123,8 @@ def get_sampler(
|
|||||||
"num_train_timesteps": 1000,
|
"num_train_timesteps": 1000,
|
||||||
"shift": 3.0
|
"shift": 3.0
|
||||||
}
|
}
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Sampler {sampler} not supported")
|
||||||
|
|
||||||
|
|
||||||
config = copy.deepcopy(config_to_use)
|
config = copy.deepcopy(config_to_use)
|
||||||
|
|||||||
Reference in New Issue
Block a user