mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-13 14:39:50 +00:00
Working multi gpu training. Still need a lot of tweaks and testing.
This commit is contained in:
@@ -20,6 +20,8 @@ from toolkit.buckets import get_bucket_for_image_size, BucketResolution
|
||||
from toolkit.config_modules import DatasetConfig, preprocess_dataset_raw_config
|
||||
from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin, LatentCachingMixin, Augments, CLIPCachingMixin
|
||||
from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
|
||||
from toolkit.print import print_acc
|
||||
from toolkit.accelerator import get_accelerator
|
||||
|
||||
import platform
|
||||
|
||||
@@ -90,7 +92,7 @@ class ImageDataset(Dataset, CaptionMixin):
|
||||
file.lower().endswith(('.jpg', '.jpeg', '.png', '.webp'))]
|
||||
|
||||
# this might take a while
|
||||
print(f" - Preprocessing image dimensions")
|
||||
print_acc(f" - Preprocessing image dimensions")
|
||||
new_file_list = []
|
||||
bad_count = 0
|
||||
for file in tqdm(self.file_list):
|
||||
@@ -102,8 +104,8 @@ class ImageDataset(Dataset, CaptionMixin):
|
||||
|
||||
self.file_list = new_file_list
|
||||
|
||||
print(f" - Found {len(self.file_list)} images")
|
||||
print(f" - Found {bad_count} images that are too small")
|
||||
print_acc(f" - Found {len(self.file_list)} images")
|
||||
print_acc(f" - Found {bad_count} images that are too small")
|
||||
assert len(self.file_list) > 0, f"no images found in {self.path}"
|
||||
|
||||
self.transform = transforms.Compose([
|
||||
@@ -128,8 +130,8 @@ class ImageDataset(Dataset, CaptionMixin):
|
||||
try:
|
||||
img = exif_transpose(Image.open(img_path)).convert('RGB')
|
||||
except Exception as e:
|
||||
print(f"Error opening image: {img_path}")
|
||||
print(e)
|
||||
print_acc(f"Error opening image: {img_path}")
|
||||
print_acc(e)
|
||||
# make a noise image if we can't open it
|
||||
img = Image.fromarray(np.random.randint(0, 255, (1024, 1024, 3), dtype=np.uint8))
|
||||
|
||||
@@ -140,7 +142,7 @@ class ImageDataset(Dataset, CaptionMixin):
|
||||
if self.random_crop:
|
||||
if self.random_scale and min_img_size > self.resolution:
|
||||
if min_img_size < self.resolution:
|
||||
print(
|
||||
print_acc(
|
||||
f"Unexpected values: min_img_size={min_img_size}, self.resolution={self.resolution}, image file={img_path}")
|
||||
scale_size = self.resolution
|
||||
else:
|
||||
@@ -243,11 +245,11 @@ class PairedImageDataset(Dataset):
|
||||
matched_files = [t for t in (set(tuple(i) for i in matched_files))]
|
||||
|
||||
self.file_list = matched_files
|
||||
print(f" - Found {len(self.file_list)} matching pairs")
|
||||
print_acc(f" - Found {len(self.file_list)} matching pairs")
|
||||
else:
|
||||
self.file_list = [os.path.join(self.path, file) for file in os.listdir(self.path) if
|
||||
file.lower().endswith(supported_exts)]
|
||||
print(f" - Found {len(self.file_list)} images")
|
||||
print_acc(f" - Found {len(self.file_list)} images")
|
||||
|
||||
self.transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
@@ -435,11 +437,12 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti
|
||||
])
|
||||
|
||||
# this might take a while
|
||||
print(f"Dataset: {self.dataset_path}")
|
||||
print(f" - Preprocessing image dimensions")
|
||||
print_acc(f"Dataset: {self.dataset_path}")
|
||||
print_acc(f" - Preprocessing image dimensions")
|
||||
dataset_folder = self.dataset_path
|
||||
if not os.path.isdir(self.dataset_path):
|
||||
dataset_folder = os.path.dirname(dataset_folder)
|
||||
|
||||
dataset_size_file = os.path.join(dataset_folder, '.aitk_size.json')
|
||||
dataloader_version = "0.1.1"
|
||||
if os.path.exists(dataset_size_file):
|
||||
@@ -448,12 +451,12 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti
|
||||
self.size_database = json.load(f)
|
||||
|
||||
if "__version__" not in self.size_database or self.size_database["__version__"] != dataloader_version:
|
||||
print("Upgrading size database to new version")
|
||||
print_acc("Upgrading size database to new version")
|
||||
# old version, delete and recreate
|
||||
self.size_database = {}
|
||||
except Exception as e:
|
||||
print(f"Error loading size database: {dataset_size_file}")
|
||||
print(e)
|
||||
print_acc(f"Error loading size database: {dataset_size_file}")
|
||||
print_acc(e)
|
||||
self.size_database = {}
|
||||
else:
|
||||
self.size_database = {}
|
||||
@@ -473,22 +476,22 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti
|
||||
)
|
||||
self.file_list.append(file_item)
|
||||
except Exception as e:
|
||||
print(traceback.format_exc())
|
||||
print(f"Error processing image: {file}")
|
||||
print(e)
|
||||
print_acc(traceback.format_exc())
|
||||
print_acc(f"Error processing image: {file}")
|
||||
print_acc(e)
|
||||
bad_count += 1
|
||||
|
||||
# save the size database
|
||||
with open(dataset_size_file, 'w') as f:
|
||||
json.dump(self.size_database, f)
|
||||
|
||||
print(f" - Found {len(self.file_list)} images")
|
||||
# print(f" - Found {bad_count} images that are too small")
|
||||
print_acc(f" - Found {len(self.file_list)} images")
|
||||
# print_acc(f" - Found {bad_count} images that are too small")
|
||||
assert len(self.file_list) > 0, f"no images found in {self.dataset_path}"
|
||||
|
||||
# handle x axis flips
|
||||
if self.dataset_config.flip_x:
|
||||
print(" - adding x axis flips")
|
||||
print_acc(" - adding x axis flips")
|
||||
current_file_list = [x for x in self.file_list]
|
||||
for file_item in current_file_list:
|
||||
# create a copy that is flipped on the x axis
|
||||
@@ -498,7 +501,7 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti
|
||||
|
||||
# handle y axis flips
|
||||
if self.dataset_config.flip_y:
|
||||
print(" - adding y axis flips")
|
||||
print_acc(" - adding y axis flips")
|
||||
current_file_list = [x for x in self.file_list]
|
||||
for file_item in current_file_list:
|
||||
# create a copy that is flipped on the y axis
|
||||
@@ -507,7 +510,7 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti
|
||||
self.file_list.append(new_file_item)
|
||||
|
||||
if self.dataset_config.flip_x or self.dataset_config.flip_y:
|
||||
print(f" - Found {len(self.file_list)} images after adding flips")
|
||||
print_acc(f" - Found {len(self.file_list)} images after adding flips")
|
||||
|
||||
|
||||
self.setup_epoch()
|
||||
|
||||
Reference in New Issue
Block a user