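"""Quick visual test of the ai-toolkit dataloader: iterates an
AiToolkitDataset with bucketing enabled and shows each batch on screen,
or saves it to --output_path."""
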
import time
import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
import sys
import os
import cv2
import random
from transformers import CLIPImageProcessor

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import torchvision.transforms.functional
from toolkit.image_utils import save_tensors, show_img, show_tensors

from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
from toolkit.data_loader import AiToolkitDataset, get_dataloader_from_datasets, \
    trigger_dataloader_setup_epoch
from toolkit.config_modules import DatasetConfig
import argparse
from tqdm import tqdm

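# CLI: a positional dataset folder, plus optional epoch count, frames per
# video sample, and an output directory (when unset, batches are shown on screen)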
parser = argparse.ArgumentParser()
parser.add_argument('dataset_folder', type=str, default='input')
parser.add_argument('--epochs', type=int, default=1)
parser.add_argument('--num_frames', type=int, default=1)
parser.add_argument('--output_path', type=str, default=None)

args = parser.parse_args()

if args.output_path is not None:
    args.output_path = os.path.abspath(args.output_path)
    os.makedirs(args.output_path, exist_ok=True)

dataset_folder = args.dataset_folder
resolution = 512
bucket_tolerance = 64
batch_size = 1

clip_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch16")

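# Minimal stand-ins for the StableDiffusion object the dataloader is handed;
# only the sd.adapter.clip_image_processor attribute chain is needed here.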
class FakeAdapter:
    def __init__(self):
        self.clip_image_processor = clip_processor


## make fake sd
class FakeSD:
    def __init__(self):
        self.adapter = FakeAdapter()


dataset_config = DatasetConfig(
    dataset_path=dataset_folder,
    # clip_image_path=dataset_folder,
    # square_crop=True,
    resolution=resolution,
    # caption_ext='json',
    default_caption='default',
    # clip_image_path='/mnt/Datasets2/regs/yetibear_xl_v14/random_aspect/',
    buckets=True,
    bucket_tolerance=bucket_tolerance,
    shrink_video_to_frames=True,
    num_frames=args.num_frames,
    # poi='person',
    # shuffle_augmentations=True,
    # augmentations=[
    #     {
    #         'method': 'Posterize',
    #         'num_bits': [(0, 4), (0, 4), (0, 4)],
    #         'p': 1.0
    #     },
    #
    # ]
)

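# build the dataloader; the FakeSD instance supplies the CLIP image processor
# the loader would normally pull from a real StableDiffusion model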
dataloader: DataLoader = get_dataloader_from_datasets([dataset_config], batch_size=batch_size, sd=FakeSD())


# run through the epochs and check sizes
idx = 0
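# iterate every batch, either saving it to disk or showing it in an OpenCV
# window (the short sleep below keeps frames visible)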
for epoch in range(args.epochs):
    for batch in tqdm(dataloader):
        batch: 'DataLoaderBatchDTO'
        img_batch = batch.tensor
        frames = 1
        if len(img_batch.shape) == 5:
            # video batch: (batch, frames, channels, height, width)
            frames = img_batch.shape[1]
            batch_size, frames, channels, height, width = img_batch.shape
        else:
            # image batch: (batch, channels, height, width)
            batch_size, channels, height, width = img_batch.shape

        # img_batch = color_block_imgs(img_batch, neg1_1=True)

        # chunks = torch.chunk(img_batch, batch_size, dim=0)
        # # put them side by side
        # big_img = torch.cat(chunks, dim=3)
        # big_img = big_img.squeeze(0)
        #
        # control_chunks = torch.chunk(batch.clip_image_tensor, batch_size, dim=0)
        # big_control_img = torch.cat(control_chunks, dim=3)
        # big_control_img = big_control_img.squeeze(0) * 2 - 1
        #
        #
        # # resize control image
        # big_control_img = torchvision.transforms.Resize((width, height))(big_control_img)
        #
        # big_img = torch.cat([big_img, big_control_img], dim=2)
        #
        # min_val = big_img.min()
        # max_val = big_img.max()
        #
        # big_img = (big_img / 2 + 0.5).clamp(0, 1)

        big_img = img_batch
        # big_img = big_img.clamp(-1, 1)
        if args.output_path is not None:
            save_tensors(big_img, os.path.join(args.output_path, f'{idx}.png'))
        else:
            show_tensors(big_img)

        # convert to image
        # img = transforms.ToPILImage()(big_img)
        #
        # show_img(img)

        time.sleep(0.2)
        idx += 1
    # if not last epoch
    if epoch < args.epochs - 1:
        trigger_dataloader_setup_epoch(dataloader)

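# close any OpenCV preview windows left open by show_tensors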
cv2.destroyAllWindows()

print('done')