mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-01 03:31:35 +00:00
Adjust dataloader tester to handle videos to test them
This commit is contained in:
@@ -50,9 +50,13 @@ class FakeAdapter:
|
|||||||
class FakeSD:
|
class FakeSD:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.adapter = FakeAdapter()
|
self.adapter = FakeAdapter()
|
||||||
|
self.use_raw_control_images = False
|
||||||
|
|
||||||
|
def encode_control_in_text_embeddings(self, *args, **kwargs):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_bucket_divisibility(self):
|
||||||
|
return 32
|
||||||
|
|
||||||
dataset_config = DatasetConfig(
|
dataset_config = DatasetConfig(
|
||||||
dataset_path=dataset_folder,
|
dataset_path=dataset_folder,
|
||||||
@@ -120,7 +124,11 @@ for epoch in range(args.epochs):
|
|||||||
big_img = img_batch
|
big_img = img_batch
|
||||||
# big_img = big_img.clamp(-1, 1)
|
# big_img = big_img.clamp(-1, 1)
|
||||||
if args.output_path is not None:
|
if args.output_path is not None:
|
||||||
save_tensors(big_img, os.path.join(args.output_path, f'{idx}.png'))
|
if len(img_batch.shape) == 5:
|
||||||
|
# video
|
||||||
|
save_tensors(big_img, os.path.join(args.output_path, f'{idx}.webp'), fps=16)
|
||||||
|
else:
|
||||||
|
save_tensors(big_img, os.path.join(args.output_path, f'{idx}.png'))
|
||||||
else:
|
else:
|
||||||
show_tensors(big_img)
|
show_tensors(big_img)
|
||||||
|
|
||||||
|
|||||||
@@ -482,7 +482,7 @@ def show_tensors(imgs: torch.Tensor, name='AI Toolkit'):
|
|||||||
|
|
||||||
show_img(img_numpy[0], name=name)
|
show_img(img_numpy[0], name=name)
|
||||||
|
|
||||||
def save_tensors(imgs: torch.Tensor, path='output.png'):
|
def save_tensors(imgs: torch.Tensor, path='output.png', fps=None):
|
||||||
if len(imgs.shape) == 5 and imgs.shape[0] == 1:
|
if len(imgs.shape) == 5 and imgs.shape[0] == 1:
|
||||||
imgs = imgs.squeeze(0)
|
imgs = imgs.squeeze(0)
|
||||||
if len(imgs.shape) == 4:
|
if len(imgs.shape) == 4:
|
||||||
@@ -490,17 +490,28 @@ def save_tensors(imgs: torch.Tensor, path='output.png'):
|
|||||||
else:
|
else:
|
||||||
img_list = [imgs]
|
img_list = [imgs]
|
||||||
|
|
||||||
img = torch.cat(img_list, dim=3)
|
num_frames = len(img_list)
|
||||||
|
print(f"Saving {num_frames} frames to {path} at {fps} fps")
|
||||||
|
if fps is not None and num_frames > 1:
|
||||||
|
img = torch.cat(img_list, dim=0)
|
||||||
|
else:
|
||||||
|
img = torch.cat(img_list, dim=3)
|
||||||
img = img / 2 + 0.5
|
img = img / 2 + 0.5
|
||||||
img_numpy = img.to(torch.float32).detach().cpu().numpy()
|
img_numpy = img.to(torch.float32).detach().cpu().numpy()
|
||||||
img_numpy = np.clip(img_numpy, 0, 1) * 255
|
img_numpy = np.clip(img_numpy, 0, 1) * 255
|
||||||
img_numpy = img_numpy.transpose(0, 2, 3, 1)
|
img_numpy = img_numpy.transpose(0, 2, 3, 1)
|
||||||
img_numpy = img_numpy.astype(np.uint8)
|
img_numpy = img_numpy.astype(np.uint8)
|
||||||
# concat images to one
|
|
||||||
img_numpy = np.concatenate(img_numpy, axis=1)
|
if fps is not None and num_frames > 1:
|
||||||
# conver to pil
|
img_list = [PILImage.fromarray(img_numpy[i]) for i in range(num_frames)]
|
||||||
img_pil = PILImage.fromarray(img_numpy)
|
duration = int(1000 / fps)
|
||||||
img_pil.save(path)
|
img_list[0].save(path, save_all=True, append_images=img_list[1:], duration=duration, loop=0, quality=95)
|
||||||
|
else:
|
||||||
|
# concat images to one
|
||||||
|
img_numpy = np.concatenate(img_numpy, axis=1)
|
||||||
|
# conver to pil
|
||||||
|
img_pil = PILImage.fromarray(img_numpy)
|
||||||
|
img_pil.save(path)
|
||||||
|
|
||||||
def show_latents(latents: torch.Tensor, vae: 'AutoencoderTiny', name='AI Toolkit'):
|
def show_latents(latents: torch.Tensor, vae: 'AutoencoderTiny', name='AI Toolkit'):
|
||||||
if vae.device == 'cpu':
|
if vae.device == 'cpu':
|
||||||
|
|||||||
Reference in New Issue
Block a user