mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-12 22:19:48 +00:00
Fixed big issue with bucketing dataloader and added random cripping to a point of interest
This commit is contained in:
@@ -2,6 +2,7 @@ import time
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision import transforms
|
||||
import sys
|
||||
import os
|
||||
@@ -16,12 +17,14 @@ sys.path.append(SD_SCRIPTS_ROOT)
|
||||
|
||||
from library.model_util import load_vae
|
||||
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
|
||||
from toolkit.data_loader import AiToolkitDataset, get_dataloader_from_datasets
|
||||
from toolkit.data_loader import AiToolkitDataset, get_dataloader_from_datasets, \
|
||||
trigger_dataloader_setup_epoch
|
||||
from toolkit.config_modules import DatasetConfig
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('dataset_folder', type=str, default='input')
|
||||
parser.add_argument('--epochs', type=int, default=1)
|
||||
|
||||
|
||||
args = parser.parse_args()
|
||||
@@ -34,38 +37,44 @@ batch_size = 4
|
||||
dataset_config = DatasetConfig(
|
||||
dataset_path=dataset_folder,
|
||||
resolution=resolution,
|
||||
caption_ext='txt',
|
||||
caption_ext='json',
|
||||
default_caption='default',
|
||||
buckets=True,
|
||||
bucket_tolerance=bucket_tolerance,
|
||||
augments=['ColorJitter', 'RandomEqualize'],
|
||||
augments=['ColorJitter'],
|
||||
poi='person'
|
||||
|
||||
)
|
||||
|
||||
dataloader = get_dataloader_from_datasets([dataset_config], batch_size=batch_size)
|
||||
dataloader: DataLoader = get_dataloader_from_datasets([dataset_config], batch_size=batch_size)
|
||||
|
||||
|
||||
# run through an epoch ang check sizes
|
||||
for batch in dataloader:
|
||||
batch: 'DataLoaderBatchDTO'
|
||||
img_batch = batch.tensor
|
||||
dataloader_iterator = iter(dataloader)
|
||||
for epoch in range(args.epochs):
|
||||
for batch in dataloader:
|
||||
batch: 'DataLoaderBatchDTO'
|
||||
img_batch = batch.tensor
|
||||
|
||||
chunks = torch.chunk(img_batch, batch_size, dim=0)
|
||||
# put them so they are size by side
|
||||
big_img = torch.cat(chunks, dim=3)
|
||||
big_img = big_img.squeeze(0)
|
||||
chunks = torch.chunk(img_batch, batch_size, dim=0)
|
||||
# put them so they are size by side
|
||||
big_img = torch.cat(chunks, dim=3)
|
||||
big_img = big_img.squeeze(0)
|
||||
|
||||
min_val = big_img.min()
|
||||
max_val = big_img.max()
|
||||
min_val = big_img.min()
|
||||
max_val = big_img.max()
|
||||
|
||||
big_img = (big_img / 2 + 0.5).clamp(0, 1)
|
||||
big_img = (big_img / 2 + 0.5).clamp(0, 1)
|
||||
|
||||
# convert to image
|
||||
img = transforms.ToPILImage()(big_img)
|
||||
# convert to image
|
||||
img = transforms.ToPILImage()(big_img)
|
||||
|
||||
show_img(img)
|
||||
show_img(img)
|
||||
|
||||
time.sleep(1.0)
|
||||
time.sleep(1.0)
|
||||
# if not last epoch
|
||||
if epoch < args.epochs - 1:
|
||||
trigger_dataloader_setup_epoch(dataloader)
|
||||
|
||||
cv2.destroyAllWindows()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user