mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-25 16:59:22 +00:00
Added ability to load video datasets and train with them
This commit is contained in:
@@ -13,7 +13,7 @@ from transformers import CLIPImageProcessor
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
from toolkit.paths import SD_SCRIPTS_ROOT
|
||||
import torchvision.transforms.functional
|
||||
from toolkit.image_utils import show_img, show_tensors
|
||||
from toolkit.image_utils import save_tensors, show_img, show_tensors
|
||||
|
||||
sys.path.append(SD_SCRIPTS_ROOT)
|
||||
|
||||
@@ -28,13 +28,18 @@ from tqdm import tqdm
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('dataset_folder', type=str, default='input')
|
||||
parser.add_argument('--epochs', type=int, default=1)
|
||||
|
||||
parser.add_argument('--num_frames', type=int, default=1)
|
||||
parser.add_argument('--output_path', type=str, default=None)
|
||||
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.output_path is not None:
|
||||
args.output_path = os.path.abspath(args.output_path)
|
||||
os.makedirs(args.output_path, exist_ok=True)
|
||||
|
||||
dataset_folder = args.dataset_folder
|
||||
resolution = 1024
|
||||
resolution = 512
|
||||
bucket_tolerance = 64
|
||||
batch_size = 1
|
||||
|
||||
@@ -63,6 +68,8 @@ dataset_config = DatasetConfig(
|
||||
# clip_image_path='/mnt/Datasets2/regs/yetibear_xl_v14/random_aspect/',
|
||||
buckets=True,
|
||||
bucket_tolerance=bucket_tolerance,
|
||||
shrink_video_to_frames=True,
|
||||
num_frames=args.num_frames,
|
||||
# poi='person',
|
||||
# shuffle_augmentations=True,
|
||||
# augmentations=[
|
||||
@@ -80,11 +87,17 @@ dataloader: DataLoader = get_dataloader_from_datasets([dataset_config], batch_si
|
||||
|
||||
# run through an epoch ang check sizes
|
||||
dataloader_iterator = iter(dataloader)
|
||||
idx = 0
|
||||
for epoch in range(args.epochs):
|
||||
for batch in tqdm(dataloader):
|
||||
batch: 'DataLoaderBatchDTO'
|
||||
img_batch = batch.tensor
|
||||
batch_size, channels, height, width = img_batch.shape
|
||||
frames = 1
|
||||
if len(img_batch.shape) == 5:
|
||||
frames = img_batch.shape[1]
|
||||
batch_size, frames, channels, height, width = img_batch.shape
|
||||
else:
|
||||
batch_size, channels, height, width = img_batch.shape
|
||||
|
||||
# img_batch = color_block_imgs(img_batch, neg1_1=True)
|
||||
|
||||
@@ -110,15 +123,18 @@ for epoch in range(args.epochs):
|
||||
|
||||
big_img = img_batch
|
||||
# big_img = big_img.clamp(-1, 1)
|
||||
if args.output_path is not None:
|
||||
save_tensors(big_img, os.path.join(args.output_path, f'{idx}.png'))
|
||||
else:
|
||||
show_tensors(big_img)
|
||||
|
||||
show_tensors(big_img)
|
||||
# convert to image
|
||||
# img = transforms.ToPILImage()(big_img)
|
||||
#
|
||||
# show_img(img)
|
||||
|
||||
# convert to image
|
||||
# img = transforms.ToPILImage()(big_img)
|
||||
#
|
||||
# show_img(img)
|
||||
|
||||
time.sleep(0.2)
|
||||
time.sleep(0.2)
|
||||
idx += 1
|
||||
# if not last epoch
|
||||
if epoch < args.epochs - 1:
|
||||
trigger_dataloader_setup_epoch(dataloader)
|
||||
|
||||
Reference in New Issue
Block a user