Added cogview4. Loss still needs work.

This commit is contained in:
Jaret Burkett
2025-03-04 18:43:52 -07:00
parent c57434ad7b
commit 6f6fb90812
12 changed files with 661 additions and 152 deletions

View File

@@ -29,7 +29,7 @@ def paramiter_count(model):
return int(paramiter_count)
def calculate_metrics(vae, images, max_imgs=-1):
def calculate_metrics(vae, images, max_imgs=-1, save_output=False):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vae = vae.to(device)
lpips_model = lpips.LPIPS(net='alex').to(device)
@@ -44,6 +44,9 @@ def calculate_metrics(vae, images, max_imgs=-1):
# ])
# needs values between -1 and 1
to_tensor = ToTensor()
# remove _reconstructed.png files
images = [img for img in images if not img.endswith("_reconstructed.png")]
if max_imgs > 0 and len(images) > max_imgs:
images = images[:max_imgs]
@@ -82,6 +85,15 @@ def calculate_metrics(vae, images, max_imgs=-1):
avg_rfid = 0
avg_psnr = sum(psnr_scores) / len(psnr_scores)
avg_lpips = sum(lpips_scores) / len(lpips_scores)
if save_output:
filename_no_ext = os.path.splitext(os.path.basename(img_path))[0]
folder = os.path.dirname(img_path)
save_path = os.path.join(folder, filename_no_ext + "_reconstructed.png")
reconstructed = (reconstructed + 1) / 2
reconstructed = reconstructed.clamp(0, 1)
reconstructed = transforms.ToPILImage()(reconstructed[0].cpu())
reconstructed.save(save_path)
return avg_rfid, avg_psnr, avg_lpips
@@ -91,18 +103,23 @@ def main():
parser.add_argument("--vae_path", type=str, required=True, help="Path to the VAE model")
parser.add_argument("--image_folder", type=str, required=True, help="Path to the folder containing images")
parser.add_argument("--max_imgs", type=int, default=-1, help="Max num of images. Default is -1 for all images.")
# boolean store true
parser.add_argument("--save_output", action="store_true", help="Save the output images")
args = parser.parse_args()
if os.path.isfile(args.vae_path):
vae = AutoencoderKL.from_single_file(args.vae_path)
else:
vae = AutoencoderKL.from_pretrained(args.vae_path)
try:
vae = AutoencoderKL.from_pretrained(args.vae_path)
except:
vae = AutoencoderKL.from_pretrained(args.vae_path, subfolder="vae")
vae.eval()
vae = vae.to(device)
print(f"Model has {paramiter_count(vae)} parameters")
images = load_images(args.image_folder)
avg_rfid, avg_psnr, avg_lpips = calculate_metrics(vae, images, args.max_imgs)
avg_rfid, avg_psnr, avg_lpips = calculate_metrics(vae, images, args.max_imgs, args.save_output)
# print(f"Average rFID: {avg_rfid}")
print(f"Average PSNR: {avg_psnr}")