diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index 5a6287ec..801337f2 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -807,10 +807,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
         with self.timer('prepare_latents'):
             dtype = get_torch_dtype(self.train_config.dtype)
             imgs = None
+            is_reg = any(batch.get_is_reg_list())
             if batch.tensor is not None:
                 imgs = batch.tensor
                 imgs = imgs.to(self.device_torch, dtype=dtype)
-                if self.train_config.img_multiplier is not None:
+                # don't adjust for regs.
+                if self.train_config.img_multiplier is not None and not is_reg:
                     # do it ad contrast
                     imgs = reduce_contrast(imgs, self.train_config.img_multiplier)
             if batch.latents is not None:
@@ -1495,8 +1497,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
 
             try:
                 print(f"Loading optimizer state from {optimizer_state_file_path}")
-                optimizer_state_dict = torch.load(optimizer_state_file_path)
+                optimizer_state_dict = torch.load(optimizer_state_file_path, weights_only=True)
                 optimizer.load_state_dict(optimizer_state_dict)
+                del optimizer_state_dict
+                flush()
             except Exception as e:
                 print(f"Failed to load optimizer state from {optimizer_state_file_path}")
                 print(e)
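The optimizer-state change above does two things worth calling out: `weights_only=True` keeps `torch.load` from unpickling arbitrary objects, and dropping the state dict right after `load_state_dict` releases an extra full copy of the optimizer state before training resumes. A minimal sketch of the pattern, assuming `flush()` wraps `gc.collect()` plus `torch.cuda.empty_cache()` (the toolkit's own helper may differ):

import gc
import torch


def flush():
    # stand-in for the toolkit helper; assumed behaviour only
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def load_optimizer_state(optimizer: torch.optim.Optimizer, path: str):
    # weights_only=True restricts torch.load to tensors and plain containers,
    # so a tampered checkpoint cannot run arbitrary code during unpickling.
    state = torch.load(path, weights_only=True)
    optimizer.load_state_dict(state)
    # load_state_dict copies the values into the optimizer, so the loaded
    # dict is dead weight afterwards; drop it and release cached memory.
    del state
    flush()
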
diff --git a/testing/test_bucket_dataloader.py b/testing/test_bucket_dataloader.py
index a430d707..6be8bddd 100644
--- a/testing/test_bucket_dataloader.py
+++ b/testing/test_bucket_dataloader.py
@@ -8,6 +8,7 @@ import sys
 import os
 import cv2
 import random
+from transformers import CLIPImageProcessor
 
 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 from toolkit.paths import SD_SCRIPTS_ROOT
@@ -37,12 +38,25 @@ resolution = 512
 bucket_tolerance = 64
 batch_size = 1
 
+clip_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch16")
+
+class FakeAdapter:
+    def __init__(self):
+        self.clip_image_processor = clip_processor
+
+
+## make fake sd
+class FakeSD:
+    def __init__(self):
+        self.adapter = FakeAdapter()
+
+
-##
 dataset_config = DatasetConfig(
     dataset_path=dataset_folder,
-    control_path=dataset_folder,
+    clip_image_path=dataset_folder,
+    square_crop=True,
     resolution=resolution,
     # caption_ext='json',
     default_caption='default',
@@ -61,123 +75,7 @@ dataset_config = DatasetConfig(
     # ]
 )
 
-dataloader: DataLoader = get_dataloader_from_datasets([dataset_config], batch_size=batch_size)
-
-def random_blur(img, min_kernel_size=3, max_kernel_size=23, p=0.5):
-    if random.random() < p:
-        kernel_size = random.randint(min_kernel_size, max_kernel_size)
-        # make sure it is odd
-        if kernel_size % 2 == 0:
-            kernel_size += 1
-        img = torchvision.transforms.functional.gaussian_blur(img, kernel_size=kernel_size)
-    return img
-
-def quantize(image, palette):
-    """
-    Similar to PIL.Image.quantize() in PyTorch. Built to maintain gradient.
-    Only works for one image i.e. CHW. Does NOT work for batches.
-    ref https://discuss.pytorch.org/t/color-quantization/104528/4
-    """
-
-    orig_dtype = image.dtype
-
-    C, H, W = image.shape
-    n_colors = palette.shape[0]
-
-    # Easier to work with list of colors
-    flat_img = image.view(C, -1).T  # [C, H, W] -> [H*W, C]
-
-    # Repeat image so that there are n_color number of columns of the same image
-    flat_img_per_color = flat_img.unsqueeze(1).expand(-1, n_colors, -1)  # [H*W, C] -> [H*W, n_colors, C]
-
-    # Get euclidean distance between each pixel in each column and the column's respective color
-    # i.e. column 1 lists distance of each pixel to color #1 in palette, column 2 to color #2 etc.
-    squared_distance = (flat_img_per_color - palette.unsqueeze(0)) ** 2
-    euclidean_distance = torch.sqrt(torch.sum(squared_distance, dim=-1) + 1e-8)  # [H*W, n_colors, C] -> [H*W, n_colors]
-
-    # Get the shortest distance (one value per row (H*W) is selected)
-    min_distances, min_indices = torch.min(euclidean_distance, dim=-1)  # [H*W, n_colors] -> [H*W]
-
-    # Create a mask for the closest colors
-    one_hot_mask = torch.nn.functional.one_hot(min_indices, num_classes=n_colors).float()  # [H*W, n_colors]
-
-    # Multiply the mask with the palette colors to get the quantized image
-    quantized = torch.matmul(one_hot_mask, palette)  # [H*W, n_colors] @ [n_colors, C] -> [H*W, C]
-
-    # Reshape it back to the original input format.
-    quantized_img = quantized.T.view(C, H, W)  # [H*W, C] -> [C, H, W]
-
-    return quantized_img.to(orig_dtype)
-
-
-
-def color_block_imgs(img, neg1_1=False):
-    # expects values 0 - 1
-    orig_dtype = img.dtype
-    if neg1_1:
-        img = img * 0.5 + 0.5
-
-    img = img * 255
-    img = img.clamp(0, 255)
-    img = img.to(torch.uint8)
-
-    img_chunks = torch.chunk(img, img.shape[0], dim=0)
-
-    posterized_chunks = []
-
-    for chunk in img_chunks:
-        img_size = (chunk.shape[2] + chunk.shape[3]) // 2
-        # min kernel size of 1% of image, max 10%
-        min_kernel_size = int(img_size * 0.01)
-        max_kernel_size = int(img_size * 0.1)
-
-        # blur first
-        chunk = random_blur(chunk, min_kernel_size=min_kernel_size, max_kernel_size=max_kernel_size, p=0.8)
-        num_colors = random.randint(1, 16)
-
-        resize_to = 16
-        # chunk = torchvision.transforms.functional.posterize(chunk, num_bits_to_use)
-
-        # mean_color = [int(x.item()) for x in torch.mean(chunk.float(), dim=(0, 2, 3))]
-
-        # shrink the image down to num_colors x num_colors
-        shrunk = torchvision.transforms.functional.resize(chunk, [resize_to, resize_to])
-
-        mean_color = [int(x.item()) for x in torch.mean(shrunk.float(), dim=(0, 2, 3))]
-
-        colors = shrunk.view(3, -1).T
-        # remove duplicates
-        colors = torch.unique(colors, dim=0)
-        colors = colors.numpy()
-        colors = colors.tolist()
-
-        use_colors = [random.choice(colors) for _ in range(num_colors)]
-
-        pallette = torch.tensor([
-            [0, 0, 0],
-            mean_color,
-            [255, 255, 255],
-        ] + use_colors, dtype=torch.float32)
-        chunk = quantize(chunk.squeeze(0), pallette).unsqueeze(0)
-
-        # chunk = torchvision.transforms.functional.equalize(chunk)
-        # color jitter
-        if random.random() < 0.5:
-            chunk = torchvision.transforms.functional.adjust_contrast(chunk, random.uniform(1.0, 1.5))
-        if random.random() < 0.5:
-            chunk = torchvision.transforms.functional.adjust_saturation(chunk, random.uniform(1.0, 2.0))
-        # if random.random() < 0.5:
-        #     chunk = torchvision.transforms.functional.adjust_brightness(chunk, random.uniform(0.5, 1.5))
-        chunk = random_blur(chunk, p=0.6)
-        posterized_chunks.append(chunk)
-
-    img = torch.cat(posterized_chunks, dim=0)
-    img = img.to(orig_dtype)
-    img = img / 255
-
-    if neg1_1:
-        img = img * 2 - 1
-    return img
+dataloader: DataLoader = get_dataloader_from_datasets([dataset_config], batch_size=batch_size, sd=FakeSD())
 
 
 # run through an epoch ang check sizes
@@ -186,6 +84,7 @@ for epoch in range(args.epochs):
     for batch in tqdm(dataloader):
         batch: 'DataLoaderBatchDTO'
         img_batch = batch.tensor
+        batch_size, channels, height, width = img_batch.shape
 
         # img_batch = color_block_imgs(img_batch, neg1_1=True)
 
@@ -194,10 +93,14 @@ for epoch in range(args.epochs):
 
         big_img = torch.cat(chunks, dim=3)
         big_img = big_img.squeeze(0)
 
-        control_chunks = torch.chunk(batch.control_tensor, batch_size, dim=0)
+        control_chunks = torch.chunk(batch.clip_image_tensor, batch_size, dim=0)
         big_control_img = torch.cat(control_chunks, dim=3)
         big_control_img = big_control_img.squeeze(0) * 2 - 1
+
+        # resize control image
+        big_control_img = torchvision.transforms.Resize((width, height))(big_control_img)
+
         big_img = torch.cat([big_img, big_control_img], dim=2)
 
         min_val = big_img.min()
@@ -208,9 +111,9 @@ for epoch in range(args.epochs):
         # convert to image
         img = transforms.ToPILImage()(big_img)
 
-        # show_img(img)
+        show_img(img)
 
-        # time.sleep(1.0)
+        time.sleep(1.0)
         # if not last epoch
         if epoch < args.epochs - 1:
             trigger_dataloader_setup_epoch(dataloader)
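For context on the rewritten preview loop: `batch.tensor` is a `[B, C, H, W]` training tensor, while `batch.clip_image_tensor` comes out of the `CLIPImageProcessor` at its own resolution and value range, so the test re-ranges and resizes it before stitching everything into one displayable image. A rough, self-contained sketch of that layout step; names like `make_preview` are illustrative, not toolkit API, and the CLIP tensor is assumed to be in 0..1 here:

import torch
import torchvision.transforms as transforms


def make_preview(imgs: torch.Tensor, clip_imgs: torch.Tensor):
    # imgs:      [B, C, H, W] training tensor, roughly in -1..1
    # clip_imgs: [B, C, h, w] CLIP-preprocessed copies, assumed in 0..1
    b, c, h, w = imgs.shape
    # chunk on the batch dim and cat on the width dim -> one long row per source
    row = torch.cat(torch.chunk(imgs, b, dim=0), dim=3).squeeze(0)
    clip_row = torch.cat(torch.chunk(clip_imgs, b, dim=0), dim=3).squeeze(0) * 2 - 1
    # bring the CLIP row to the same height/width as the image row so they can be joined
    clip_row = transforms.Resize((h, row.shape[-1]))(clip_row)
    # place the CLIP row beside the image row (dim 2 of CHW is width)
    grid = torch.cat([row, clip_row], dim=2)
    # min/max normalize back to 0..1 for display
    grid = (grid - grid.min()) / (grid.max() - grid.min() + 1e-8)
    return transforms.ToPILImage()(grid)
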
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index 950d6de4..ec3efca1 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -154,6 +154,7 @@ class AdapterConfig:
         if num_tokens is None and self.type.startswith('ip'):
             if self.type == 'ip+':
                 num_tokens = 16
+                num_tokens = 16
             elif self.type == 'ip':
                 num_tokens = 4
 
diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py
index 28d52676..273495c3 100644
--- a/toolkit/dataloader_mixins.py
+++ b/toolkit/dataloader_mixins.py
@@ -760,11 +760,15 @@ class ClipImageFileItemDTOMixin:
                 # do a flip
                 img = img.transpose(Image.FLIP_TOP_BOTTOM)
 
-            # image must be square. If it is not, we will resize/squish it so it is, that way we don't crop out data
             if img.width != img.height:
-                # resize to the smallest dimension
                 min_size = min(img.width, img.height)
-                img = img.resize((min_size, min_size), Image.BICUBIC)
+                if self.dataset_config.square_crop:
+                    # center crop to a square
+                    img = transforms.CenterCrop(min_size)(img)
+                else:
+                    # image must be square. If it is not, we will resize/squish it so it is, that way we don't crop out data
+                    # resize to the smallest dimension
+                    img = img.resize((min_size, min_size), Image.BICUBIC)
 
             if self.has_clip_augmentations:
                 self.clip_image_tensor = self.augment_clip_image(img, transform=None)
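The new `square_crop` option changes how a non-square clip image is made square before CLIP preprocessing: instead of squishing both sides to `min(width, height)` (which keeps every pixel but distorts proportions), it center-crops to that size (which keeps proportions but discards the edges of the longer side). A standalone illustration of the two behaviors, using plain PIL and torchvision; `make_square` is an illustrative name, not a toolkit function:

from PIL import Image
from torchvision import transforms


def make_square(img: Image.Image, square_crop: bool) -> Image.Image:
    if img.width == img.height:
        return img
    min_size = min(img.width, img.height)
    if square_crop:
        # keep aspect ratio, drop the overhang of the long side
        return transforms.CenterCrop(min_size)(img)
    # keep all pixels, but squish the long side so proportions change
    return img.resize((min_size, min_size), Image.BICUBIC)


# e.g. a 640x480 photo becomes a 480x480 center crop with square_crop=True,
# or a distorted 480x480 resize with square_crop=False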