diff --git a/testing/test_bucket_dataloader.py b/testing/test_bucket_dataloader.py
index 9c9adf65..b4537422 100644
--- a/testing/test_bucket_dataloader.py
+++ b/testing/test_bucket_dataloader.py
@@ -50,9 +50,13 @@ class FakeAdapter:
 class FakeSD:
     def __init__(self):
         self.adapter = FakeAdapter()
+        self.use_raw_control_images = False
+
+    def encode_control_in_text_embeddings(self, *args, **kwargs):
+        return None
-
-
+    def get_bucket_divisibility(self):
+        return 32
 
 
 dataset_config = DatasetConfig(
     dataset_path=dataset_folder,
@@ -120,7 +124,11 @@ for epoch in range(args.epochs):
         big_img = img_batch
         # big_img = big_img.clamp(-1, 1)
         if args.output_path is not None:
-            save_tensors(big_img, os.path.join(args.output_path, f'{idx}.png'))
+            if len(img_batch.shape) == 5:
+                # video
+                save_tensors(big_img, os.path.join(args.output_path, f'{idx}.webp'), fps=16)
+            else:
+                save_tensors(big_img, os.path.join(args.output_path, f'{idx}.png'))
         else:
             show_tensors(big_img)
 
diff --git a/toolkit/image_utils.py b/toolkit/image_utils.py
index 0c536c43..68844e23 100644
--- a/toolkit/image_utils.py
+++ b/toolkit/image_utils.py
@@ -482,7 +482,7 @@ def show_tensors(imgs: torch.Tensor, name='AI Toolkit'):
         show_img(img_numpy[0], name=name)
 
 
-def save_tensors(imgs: torch.Tensor, path='output.png'):
+def save_tensors(imgs: torch.Tensor, path='output.png', fps=None):
     if len(imgs.shape) == 5 and imgs.shape[0] == 1:
         imgs = imgs.squeeze(0)
     if len(imgs.shape) == 4:
@@ -490,17 +490,28 @@ def save_tensors(imgs: torch.Tensor, path='output.png'):
     else:
         img_list = [imgs]
-    img = torch.cat(img_list, dim=3)
+    num_frames = len(img_list)
+    print(f"Saving {num_frames} frames to {path} at {fps} fps")
+    if fps is not None and num_frames > 1:
+        img = torch.cat(img_list, dim=0)
+    else:
+        img = torch.cat(img_list, dim=3)
     img = img / 2 + 0.5
     img_numpy = img.to(torch.float32).detach().cpu().numpy()
     img_numpy = np.clip(img_numpy, 0, 1) * 255
     img_numpy = img_numpy.transpose(0, 2, 3, 1)
     img_numpy = img_numpy.astype(np.uint8)
-    # concat images to one
-    img_numpy = np.concatenate(img_numpy, axis=1)
-    # conver to pil
-    img_pil = PILImage.fromarray(img_numpy)
-    img_pil.save(path)
+
+    if fps is not None and num_frames > 1:
+        img_list = [PILImage.fromarray(img_numpy[i]) for i in range(num_frames)]
+        duration = int(1000 / fps)
+        img_list[0].save(path, save_all=True, append_images=img_list[1:], duration=duration, loop=0, quality=95)
+    else:
+        # concat images into one
+        img_numpy = np.concatenate(img_numpy, axis=1)
+        # convert to PIL
+        img_pil = PILImage.fromarray(img_numpy)
+        img_pil.save(path)
 
 
 def show_latents(latents: torch.Tensor, vae: 'AutoencoderTiny', name='AI Toolkit'):
     if vae.device == 'cpu':
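
For context, a minimal usage sketch of the new `fps` path in `save_tensors`, assuming a `(batch, frames, channels, height, width)` tensor in the `[-1, 1]` range the function rescales; the shapes and file names here are illustrative, not from the patch:

```python
import torch
from toolkit.image_utils import save_tensors

# illustrative video batch: (batch=1, frames=16, C=3, H=256, W=256) in [-1, 1]
video = torch.rand(1, 16, 3, 256, 256) * 2 - 1

# fps set and more than one frame: frames are stacked on dim 0 and written
# as an animated image, with per-frame duration int(1000 / fps) milliseconds
save_tensors(video, 'sample.webp', fps=16)

# fps omitted: the original behavior, frames concatenated side by side
# into a single still image
save_tensors(video, 'sample.png')
```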
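The animated branch relies on Pillow's multi-frame save: `save_all=True` with `append_images` writes the remaining frames, `duration` is the per-frame display time in milliseconds (`int(1000 / 16)` is 62 ms, so roughly 16.1 fps), and `loop=0` loops forever. A self-contained sketch of that pattern, using synthetic frames in place of decoded video:

```python
import numpy as np
from PIL import Image

fps = 16
# eight random 64x64 RGB frames standing in for real video frames
frames = [Image.fromarray(np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8))
          for _ in range(8)]

# WebP (and GIF) accept the same multi-frame kwargs used in the patch
frames[0].save('demo.webp', save_all=True, append_images=frames[1:],
               duration=int(1000 / fps), loop=0, quality=95)
```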