mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Added cogview4. Loss still needs work.
This commit is contained in:
@@ -29,7 +29,7 @@ def paramiter_count(model):
|
||||
return int(paramiter_count)
|
||||
|
||||
|
||||
def calculate_metrics(vae, images, max_imgs=-1):
|
||||
def calculate_metrics(vae, images, max_imgs=-1, save_output=False):
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
vae = vae.to(device)
|
||||
lpips_model = lpips.LPIPS(net='alex').to(device)
|
||||
@@ -44,6 +44,9 @@ def calculate_metrics(vae, images, max_imgs=-1):
|
||||
# ])
|
||||
# needs values between -1 and 1
|
||||
to_tensor = ToTensor()
|
||||
|
||||
# skip *_reconstructed.png files (outputs written when save_output is enabled) so they are not evaluated as input images
|
||||
images = [img for img in images if not img.endswith("_reconstructed.png")]
|
||||
|
||||
if max_imgs > 0 and len(images) > max_imgs:
|
||||
images = images[:max_imgs]
|
||||
@@ -82,6 +85,15 @@ def calculate_metrics(vae, images, max_imgs=-1):
|
||||
avg_rfid = 0
|
||||
avg_psnr = sum(psnr_scores) / len(psnr_scores)
|
||||
avg_lpips = sum(lpips_scores) / len(lpips_scores)
|
||||
|
||||
if save_output:
|
||||
filename_no_ext = os.path.splitext(os.path.basename(img_path))[0]
|
||||
folder = os.path.dirname(img_path)
|
||||
save_path = os.path.join(folder, filename_no_ext + "_reconstructed.png")
|
||||
reconstructed = (reconstructed + 1) / 2
|
||||
reconstructed = reconstructed.clamp(0, 1)
|
||||
reconstructed = transforms.ToPILImage()(reconstructed[0].cpu())
|
||||
reconstructed.save(save_path)
|
||||
|
||||
return avg_rfid, avg_psnr, avg_lpips
|
||||
|
||||
@@ -91,18 +103,23 @@ def main():
|
||||
parser.add_argument("--vae_path", type=str, required=True, help="Path to the VAE model")
|
||||
parser.add_argument("--image_folder", type=str, required=True, help="Path to the folder containing images")
|
||||
parser.add_argument("--max_imgs", type=int, default=-1, help="Max num of images. Default is -1 for all images.")
|
||||
# boolean flag (argparse store_true): passing --save_output sets it to True, otherwise it is False
|
||||
parser.add_argument("--save_output", action="store_true", help="Save the output images")
|
||||
args = parser.parse_args()
|
||||
|
||||
if os.path.isfile(args.vae_path):
|
||||
vae = AutoencoderKL.from_single_file(args.vae_path)
|
||||
else:
|
||||
vae = AutoencoderKL.from_pretrained(args.vae_path)
|
||||
try:
|
||||
vae = AutoencoderKL.from_pretrained(args.vae_path)
|
||||
except:
|
||||
vae = AutoencoderKL.from_pretrained(args.vae_path, subfolder="vae")
|
||||
vae.eval()
|
||||
vae = vae.to(device)
|
||||
print(f"Model has {paramiter_count(vae)} parameters")
|
||||
images = load_images(args.image_folder)
|
||||
|
||||
avg_rfid, avg_psnr, avg_lpips = calculate_metrics(vae, images, args.max_imgs)
|
||||
avg_rfid, avg_psnr, avg_lpips = calculate_metrics(vae, images, args.max_imgs, args.save_output)
|
||||
|
||||
# print(f"Average rFID: {avg_rfid}")
|
||||
print(f"Average PSNR: {avg_psnr}")
|
||||
|
||||
Reference in New Issue
Block a user