mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-28 18:21:16 +00:00
Adjust dataloader tester to handle videos to test them
This commit is contained in:
@@ -50,9 +50,13 @@ class FakeAdapter:
|
||||
class FakeSD:
|
||||
def __init__(self):
|
||||
self.adapter = FakeAdapter()
|
||||
self.use_raw_control_images = False
|
||||
|
||||
def encode_control_in_text_embeddings(self, *args, **kwargs):
|
||||
return None
|
||||
|
||||
|
||||
|
||||
def get_bucket_divisibility(self):
|
||||
return 32
|
||||
|
||||
dataset_config = DatasetConfig(
|
||||
dataset_path=dataset_folder,
|
||||
@@ -120,7 +124,11 @@ for epoch in range(args.epochs):
|
||||
big_img = img_batch
|
||||
# big_img = big_img.clamp(-1, 1)
|
||||
if args.output_path is not None:
|
||||
save_tensors(big_img, os.path.join(args.output_path, f'{idx}.png'))
|
||||
if len(img_batch.shape) == 5:
|
||||
# video
|
||||
save_tensors(big_img, os.path.join(args.output_path, f'{idx}.webp'), fps=16)
|
||||
else:
|
||||
save_tensors(big_img, os.path.join(args.output_path, f'{idx}.png'))
|
||||
else:
|
||||
show_tensors(big_img)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user