mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-13 14:39:50 +00:00
Work on ipadapters and custom adapters
This commit is contained in:
@@ -21,12 +21,14 @@ from toolkit.data_loader import AiToolkitDataset, get_dataloader_from_datasets,
|
||||
trigger_dataloader_setup_epoch
|
||||
from toolkit.config_modules import DatasetConfig
|
||||
import argparse
|
||||
from tqdm import tqdm
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('dataset_folder', type=str, default='input')
|
||||
parser.add_argument('--epochs', type=int, default=1)
|
||||
|
||||
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
dataset_folder = args.dataset_folder
|
||||
@@ -40,27 +42,27 @@ batch_size = 1
|
||||
dataset_config = DatasetConfig(
|
||||
dataset_path=dataset_folder,
|
||||
resolution=resolution,
|
||||
caption_ext='json',
|
||||
# caption_ext='json',
|
||||
default_caption='default',
|
||||
clip_image_path='/mnt/Datasets/face_pairs2/control_clean',
|
||||
# clip_image_path='/mnt/Datasets2/regs/yetibear_xl_v14/random_aspect/',
|
||||
buckets=True,
|
||||
bucket_tolerance=bucket_tolerance,
|
||||
poi='person',
|
||||
augmentations=[
|
||||
{
|
||||
'method': 'RandomBrightnessContrast',
|
||||
'brightness_limit': (-0.3, 0.3),
|
||||
'contrast_limit': (-0.3, 0.3),
|
||||
'brightness_by_max': False,
|
||||
'p': 1.0
|
||||
},
|
||||
{
|
||||
'method': 'HueSaturationValue',
|
||||
'hue_shift_limit': (-0, 0),
|
||||
'sat_shift_limit': (-40, 40),
|
||||
'val_shift_limit': (-40, 40),
|
||||
'p': 1.0
|
||||
},
|
||||
# poi='person',
|
||||
# augmentations=[
|
||||
# {
|
||||
# 'method': 'RandomBrightnessContrast',
|
||||
# 'brightness_limit': (-0.3, 0.3),
|
||||
# 'contrast_limit': (-0.3, 0.3),
|
||||
# 'brightness_by_max': False,
|
||||
# 'p': 1.0
|
||||
# },
|
||||
# {
|
||||
# 'method': 'HueSaturationValue',
|
||||
# 'hue_shift_limit': (-0, 0),
|
||||
# 'sat_shift_limit': (-40, 40),
|
||||
# 'val_shift_limit': (-40, 40),
|
||||
# 'p': 1.0
|
||||
# },
|
||||
# {
|
||||
# 'method': 'RGBShift',
|
||||
# 'r_shift_limit': (-20, 20),
|
||||
@@ -68,7 +70,7 @@ dataset_config = DatasetConfig(
|
||||
# 'b_shift_limit': (-20, 20),
|
||||
# 'p': 1.0
|
||||
# },
|
||||
]
|
||||
# ]
|
||||
|
||||
|
||||
)
|
||||
@@ -79,7 +81,7 @@ dataloader: DataLoader = get_dataloader_from_datasets([dataset_config], batch_si
|
||||
# run through an epoch ang check sizes
|
||||
dataloader_iterator = iter(dataloader)
|
||||
for epoch in range(args.epochs):
|
||||
for batch in dataloader:
|
||||
for batch in tqdm(dataloader):
|
||||
batch: 'DataLoaderBatchDTO'
|
||||
img_batch = batch.tensor
|
||||
|
||||
@@ -98,7 +100,7 @@ for epoch in range(args.epochs):
|
||||
|
||||
show_img(img)
|
||||
|
||||
time.sleep(1.0)
|
||||
# time.sleep(0.1)
|
||||
# if not last epoch
|
||||
if epoch < args.epochs - 1:
|
||||
trigger_dataloader_setup_epoch(dataloader)
|
||||
|
||||
Reference in New Issue
Block a user