Adjust dataloader tester to handle videos to test them

This commit is contained in:
Jaret Burkett
2025-10-21 14:47:23 -06:00
parent 0d8a33dc16
commit 5123090f6c
2 changed files with 29 additions and 10 deletions

View File

@@ -482,7 +482,7 @@ def show_tensors(imgs: torch.Tensor, name='AI Toolkit'):
show_img(img_numpy[0], name=name)
def save_tensors(imgs: torch.Tensor, path='output.png'):
def save_tensors(imgs: torch.Tensor, path='output.png', fps=None):
if len(imgs.shape) == 5 and imgs.shape[0] == 1:
imgs = imgs.squeeze(0)
if len(imgs.shape) == 4:
@@ -490,17 +490,28 @@ def save_tensors(imgs: torch.Tensor, path='output.png'):
else:
img_list = [imgs]
img = torch.cat(img_list, dim=3)
num_frames = len(img_list)
print(f"Saving {num_frames} frames to {path} at {fps} fps")
if fps is not None and num_frames > 1:
img = torch.cat(img_list, dim=0)
else:
img = torch.cat(img_list, dim=3)
img = img / 2 + 0.5
img_numpy = img.to(torch.float32).detach().cpu().numpy()
img_numpy = np.clip(img_numpy, 0, 1) * 255
img_numpy = img_numpy.transpose(0, 2, 3, 1)
img_numpy = img_numpy.astype(np.uint8)
# concat images to one
img_numpy = np.concatenate(img_numpy, axis=1)
# conver to pil
img_pil = PILImage.fromarray(img_numpy)
img_pil.save(path)
if fps is not None and num_frames > 1:
img_list = [PILImage.fromarray(img_numpy[i]) for i in range(num_frames)]
duration = int(1000 / fps)
img_list[0].save(path, save_all=True, append_images=img_list[1:], duration=duration, loop=0, quality=95)
else:
# concat images to one
img_numpy = np.concatenate(img_numpy, axis=1)
# conver to pil
img_pil = PILImage.fromarray(img_numpy)
img_pil.save(path)
def show_latents(latents: torch.Tensor, vae: 'AutoencoderTiny', name='AI Toolkit'):
if vae.device == 'cpu':