# Files
# ai-toolkit/testing/test_bucket_dataloader.py
#
# 38 lines
# 985 B
# Python
from torch.utils.data import ConcatDataset, DataLoader
from tqdm import tqdm
# make sure we can import from the toolkit
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from toolkit.data_loader import AiToolkitDataset, get_dataloader_from_datasets
from toolkit.config_modules import DatasetConfig
import argparse
# Smoke-test the bucketed dataloader: build a dataset config for a folder of
# images, run one epoch, and print every batch's tensor shape so bucket
# resolutions can be eyeballed.
parser = argparse.ArgumentParser()
# nargs='?' makes the positional optional so the declared default ('input')
# actually applies -- a required positional silently ignores `default=`.
parser.add_argument('dataset_folder', type=str, nargs='?', default='input')
args = parser.parse_args()

dataset_folder = args.dataset_folder
resolution = 512        # base bucket resolution
bucket_tolerance = 64   # bucket dimensions snap to multiples of this
batch_size = 4

dataset_config = DatasetConfig(
    folder_path=dataset_folder,
    resolution=resolution,
    caption_ext='txt',
    default_caption='default',
    buckets=True,
    bucket_tolerance=bucket_tolerance,
)

dataloader = get_dataloader_from_datasets([dataset_config], batch_size=batch_size)

# run through an epoch and check sizes
for batch in dataloader:
    # batch[0] is the image tensor; print its shape to verify bucketing.
    print(list(batch[0].shape))

print('done')