Working multi gpu training. Still need a lot of tweaks and testing.

This commit is contained in:
Jaret Burkett
2025-01-25 16:46:20 -07:00
parent 441474e81f
commit 5e663746b8
9 changed files with 432 additions and 294 deletions

View File

@@ -20,6 +20,8 @@ from toolkit.buckets import get_bucket_for_image_size, BucketResolution
from toolkit.config_modules import DatasetConfig, preprocess_dataset_raw_config
from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin, LatentCachingMixin, Augments, CLIPCachingMixin
from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
from toolkit.print import print_acc
from toolkit.accelerator import get_accelerator
import platform
@@ -90,7 +92,7 @@ class ImageDataset(Dataset, CaptionMixin):
file.lower().endswith(('.jpg', '.jpeg', '.png', '.webp'))]
# this might take a while
print(f" - Preprocessing image dimensions")
print_acc(f" - Preprocessing image dimensions")
new_file_list = []
bad_count = 0
for file in tqdm(self.file_list):
@@ -102,8 +104,8 @@ class ImageDataset(Dataset, CaptionMixin):
self.file_list = new_file_list
print(f" - Found {len(self.file_list)} images")
print(f" - Found {bad_count} images that are too small")
print_acc(f" - Found {len(self.file_list)} images")
print_acc(f" - Found {bad_count} images that are too small")
assert len(self.file_list) > 0, f"no images found in {self.path}"
self.transform = transforms.Compose([
@@ -128,8 +130,8 @@ class ImageDataset(Dataset, CaptionMixin):
try:
img = exif_transpose(Image.open(img_path)).convert('RGB')
except Exception as e:
print(f"Error opening image: {img_path}")
print(e)
print_acc(f"Error opening image: {img_path}")
print_acc(e)
# make a noise image if we can't open it
img = Image.fromarray(np.random.randint(0, 255, (1024, 1024, 3), dtype=np.uint8))
@@ -140,7 +142,7 @@ class ImageDataset(Dataset, CaptionMixin):
if self.random_crop:
if self.random_scale and min_img_size > self.resolution:
if min_img_size < self.resolution:
print(
print_acc(
f"Unexpected values: min_img_size={min_img_size}, self.resolution={self.resolution}, image file={img_path}")
scale_size = self.resolution
else:
@@ -243,11 +245,11 @@ class PairedImageDataset(Dataset):
matched_files = [t for t in (set(tuple(i) for i in matched_files))]
self.file_list = matched_files
print(f" - Found {len(self.file_list)} matching pairs")
print_acc(f" - Found {len(self.file_list)} matching pairs")
else:
self.file_list = [os.path.join(self.path, file) for file in os.listdir(self.path) if
file.lower().endswith(supported_exts)]
print(f" - Found {len(self.file_list)} images")
print_acc(f" - Found {len(self.file_list)} images")
self.transform = transforms.Compose([
transforms.ToTensor(),
@@ -435,11 +437,12 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti
])
# this might take a while
print(f"Dataset: {self.dataset_path}")
print(f" - Preprocessing image dimensions")
print_acc(f"Dataset: {self.dataset_path}")
print_acc(f" - Preprocessing image dimensions")
dataset_folder = self.dataset_path
if not os.path.isdir(self.dataset_path):
dataset_folder = os.path.dirname(dataset_folder)
dataset_size_file = os.path.join(dataset_folder, '.aitk_size.json')
dataloader_version = "0.1.1"
if os.path.exists(dataset_size_file):
@@ -448,12 +451,12 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti
self.size_database = json.load(f)
if "__version__" not in self.size_database or self.size_database["__version__"] != dataloader_version:
print("Upgrading size database to new version")
print_acc("Upgrading size database to new version")
# old version, delete and recreate
self.size_database = {}
except Exception as e:
print(f"Error loading size database: {dataset_size_file}")
print(e)
print_acc(f"Error loading size database: {dataset_size_file}")
print_acc(e)
self.size_database = {}
else:
self.size_database = {}
@@ -473,22 +476,22 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti
)
self.file_list.append(file_item)
except Exception as e:
print(traceback.format_exc())
print(f"Error processing image: {file}")
print(e)
print_acc(traceback.format_exc())
print_acc(f"Error processing image: {file}")
print_acc(e)
bad_count += 1
# save the size database
with open(dataset_size_file, 'w') as f:
json.dump(self.size_database, f)
print(f" - Found {len(self.file_list)} images")
# print(f" - Found {bad_count} images that are too small")
print_acc(f" - Found {len(self.file_list)} images")
# print_acc(f" - Found {bad_count} images that are too small")
assert len(self.file_list) > 0, f"no images found in {self.dataset_path}"
# handle x axis flips
if self.dataset_config.flip_x:
print(" - adding x axis flips")
print_acc(" - adding x axis flips")
current_file_list = [x for x in self.file_list]
for file_item in current_file_list:
# create a copy that is flipped on the x axis
@@ -498,7 +501,7 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti
# handle y axis flips
if self.dataset_config.flip_y:
print(" - adding y axis flips")
print_acc(" - adding y axis flips")
current_file_list = [x for x in self.file_list]
for file_item in current_file_list:
# create a copy that is flipped on the y axis
@@ -507,7 +510,7 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti
self.file_list.append(new_file_item)
if self.dataset_config.flip_x or self.dataset_config.flip_y:
print(f" - Found {len(self.file_list)} images after adding flips")
print_acc(f" - Found {len(self.file_list)} images after adding flips")
self.setup_epoch()