From 833c833f28e1c3262eed140e48e2e6f38bcf255e Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Mon, 27 May 2024 10:50:24 -0600
Subject: [PATCH] WIP on SAFE encoder. Work on fp16 training improvements.
 Various other tweaks and improvements

---
 extensions_built_in/sd_trainer/SDTrainer.py  | 27 +++++++-
 testing/test_bucket_dataloader.py            | 63 +++++++++++++------
 toolkit/config_modules.py                    |  1 +
 toolkit/custom_adapter.py                    |  5 +-
 toolkit/dataloader_mixins.py                 | 17 ++++-
 toolkit/ip_adapter.py                        | 17 ++---
 toolkit/models/ilora.py                      | 17 +++--
 .../models/size_agnostic_feature_encoder.py  |  9 ++-
 toolkit/optimizer.py                         | 20 +++---
 9 files changed, 127 insertions(+), 49 deletions(-)

diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index a9ea4e4d..d08c9e91 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -31,6 +31,7 @@
 from jobs.process import BaseSDTrainProcess
 from torchvision import transforms
 
+
 def flush():
     torch.cuda.empty_cache()
     gc.collect()
@@ -55,6 +56,17 @@ class SDTrainer(BaseSDTrainProcess):
         self.negative_prompt_pool: Union[List[str], None] = None
         self.batch_negative_prompt: Union[List[str], None] = None
 
+        self.scaler = torch.cuda.amp.GradScaler()
+
+        # patch the scaler to allow fp16 training
+        org_unscale_grads = self.scaler._unscale_grads_
+
+        def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16):
+            return org_unscale_grads(optimizer, inv_scale, found_inf, True)
+
+        self.scaler._unscale_grads_ = _unscale_grads_replacer
+
+        self.is_bfloat = self.train_config.dtype == "bfloat16" or self.train_config.dtype == "bf16"
+
     def before_model_load(self):
         pass
@@ -1401,6 +1413,7 @@ class SDTrainer(BaseSDTrainProcess):
                     print("loss is nan")
                     loss = torch.zeros_like(loss).requires_grad_(True)
 
+
         with self.timer('backward'):
             # todo we have multiplier separated. works for now as res are not in same batch, but need to change
             loss = loss * loss_multiplier.mean()
@@ -1410,7 +1423,10 @@
             # 0.0 for the backward pass and the gradients will be 0.0
             # I spent weeks on fighting this. DON'T DO IT
             # with fsdp_overlap_step_with_backward():
-            loss.backward()
+            if self.is_bfloat:
+                loss.backward()
+            else:
+                self.scaler.scale(loss).backward()
             # flush()
 
         if not self.is_grad_accumulation_step:
@@ -1423,8 +1439,13 @@
                 torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm)
             # only step if we are not accumulating
             with self.timer('optimizer_step'):
-                # apply gradients
-                self.optimizer.step()
+                if self.is_bfloat:
+                    self.optimizer.step()
+                else:
+                    # apply gradients
+                    self.scaler.step(self.optimizer)
+                    self.scaler.update()
+                    # self.optimizer.step()
                 self.optimizer.zero_grad(set_to_none=True)
         else:
             # gradient accumulation. Just a place for a breakpoint
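A note on the GradScaler patch above: by default, GradScaler.unscale_ calls the private _unscale_grads_ with allow_fp16=False and raises "Attempting to unscale FP16 gradients" when the parameters themselves are fp16 (as opposed to fp32 weights run under autocast). Wrapping the private method forces allow_fp16=True, so loss scaling still works for true fp16 training. A minimal standalone sketch of the same trick; it leans on a private PyTorch API, so it may break between versions:

    import torch

    scaler = torch.cuda.amp.GradScaler()

    # keep a handle to the private method, then force allow_fp16=True
    org_unscale_grads = scaler._unscale_grads_

    def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16):
        return org_unscale_grads(optimizer, inv_scale, found_inf, True)

    scaler._unscale_grads_ = _unscale_grads_replacer

    # the usual scaler loop afterwards:
    #   scaler.scale(loss).backward()
    #   scaler.step(optimizer)
    #   scaler.update()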
diff --git a/testing/test_bucket_dataloader.py b/testing/test_bucket_dataloader.py
index eeac0d17..f830517d 100644
--- a/testing/test_bucket_dataloader.py
+++ b/testing/test_bucket_dataloader.py
@@ -48,28 +48,55 @@ dataset_config = DatasetConfig(
     buckets=True,
     bucket_tolerance=bucket_tolerance,
     # poi='person',
+    shuffle_augmentations=True,
     # augmentations=[
     #     {
-    #         'method': 'RandomBrightnessContrast',
-    #         'brightness_limit': (-0.3, 0.3),
-    #         'contrast_limit': (-0.3, 0.3),
-    #         'brightness_by_max': False,
-    #         'p': 1.0
+    #         'method': 'GaussianBlur',
+    #         'blur_limit': (1, 16),
+    #         'sigma_limit': (0, 8),
+    #         'p': 0.8
     #     },
     #     {
-    #         'method': 'HueSaturationValue',
-    #         'hue_shift_limit': (-0, 0),
-    #         'sat_shift_limit': (-40, 40),
-    #         'val_shift_limit': (-40, 40),
-    #         'p': 1.0
+    #         'method': 'ImageCompression',
+    #         'quality_lower': 10,
+    #         'quality_upper': 100,
+    #         'compression_type': 0,
+    #         'p': 0.8
+    #     },
+    #     {
+    #         'method': 'ImageCompression',
+    #         'quality_lower': 20,
+    #         'quality_upper': 100,
+    #         'compression_type': 1,
+    #         'p': 0.8
+    #     },
+    #     {
+    #         'method': 'RingingOvershoot',
+    #         'blur_limit': (3, 35),
+    #         'cutoff': (0.7, 1.96),
+    #         'p': 0.8
+    #     },
+    #     {
+    #         'method': 'GaussNoise',
+    #         'var_limit': (0, 300),
+    #         'per_channel': True,
+    #         'mean': 0.0,
+    #         'p': 0.8
+    #     },
+    #     {
+    #         'method': 'GlassBlur',
+    #         'sigma': 0.6,
+    #         'max_delta': 7,
+    #         'iterations': 2,
+    #         'mode': 'fast',
+    #         'p': 0.8
+    #     },
+    #     {
+    #         'method': 'Downscale',
+    #         'scale_max': 0.5,
+    #         'interpolation': 'cv2.INTER_CUBIC',
+    #         'p': 0.8
     #     },
-    #     {
-    #         'method': 'RGBShift',
-    #         'r_shift_limit': (-20, 20),
-    #         'g_shift_limit': (-20, 20),
-    #         'b_shift_limit': (-20, 20),
-    #         'p': 1.0
-    #     },
     # ]
@@ -100,7 +127,7 @@ for epoch in range(args.epochs):
 
         show_img(img)
 
-        # time.sleep(0.1)
+        # time.sleep(1.0)
     # if not last epoch
     if epoch < args.epochs - 1:
         trigger_dataloader_setup_epoch(dataloader)
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index 4944c89a..a6b9a9fb 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -529,6 +529,7 @@ class DatasetConfig:
         self.num_workers: int = kwargs.get('num_workers', 4)
         self.prefetch_factor: int = kwargs.get('prefetch_factor', 2)
         self.extra_values: List[float] = kwargs.get('extra_values', [])
+        self.square_crop: bool = kwargs.get('square_crop', False)
 
 
 def preprocess_dataset_raw_config(raw_config: List[dict]) -> List[dict]:
diff --git a/toolkit/custom_adapter.py b/toolkit/custom_adapter.py
index a414e69b..32e6984b 100644
--- a/toolkit/custom_adapter.py
+++ b/toolkit/custom_adapter.py
@@ -388,7 +388,10 @@ class CustomAdapter(torch.nn.Module):
             self.fuse_module.load_state_dict(state_dict['fuse_module'], strict=strict)
 
         if 'ilora' in state_dict:
-            self.ilora_module.load_state_dict(state_dict['ilora'], strict=strict)
+            try:
+                self.ilora_module.load_state_dict(state_dict['ilora'], strict=strict)
+            except Exception as e:
+                print(e)
 
         pass
diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py
index 2d55e563..6e7b13bc 100644
--- a/toolkit/dataloader_mixins.py
+++ b/toolkit/dataloader_mixins.py
@@ -203,7 +203,22 @@ class BucketsMixin:
             if file_item.has_point_of_interest:
                 # Attempt to process the poi if we can. It won't process if the image is smaller than the resolution
                 did_process_poi = file_item.setup_poi_bucket()
-            if not did_process_poi:
+            if self.dataset_config.square_crop:
+                # we scale first so the smallest side matches the resolution
+                scale_factor_x = resolution / width
+                scale_factor_y = resolution / height
+                scale_factor = max(scale_factor_x, scale_factor_y)
+                file_item.scale_to_width = math.ceil(width * scale_factor)
+                file_item.scale_to_height = math.ceil(height * scale_factor)
+                file_item.crop_width = resolution
+                file_item.crop_height = resolution
+                if width > height:
+                    file_item.crop_x = int(file_item.scale_to_width / 2 - resolution / 2)
+                    file_item.crop_y = 0
+                else:
+                    file_item.crop_x = 0
+                    file_item.crop_y = int(file_item.scale_to_height / 2 - resolution / 2)
+            elif not did_process_poi:
                 bucket_resolution = get_bucket_for_image_size(
                     width, height,
                     resolution=resolution,
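For reference, the square_crop path above scales the image so its short side lands on the target resolution, then center-crops the long axis down to a square. A standalone PIL-based sketch of the same math (illustrative only; the mixin itself just records scale and crop values on file_item):

    import math
    from PIL import Image

    def square_crop(img: Image.Image, resolution: int) -> Image.Image:
        width, height = img.size
        # scale so the smallest side matches the target resolution
        scale = max(resolution / width, resolution / height)
        new_w = math.ceil(width * scale)
        new_h = math.ceil(height * scale)
        img = img.resize((new_w, new_h))
        # center-crop along the long axis only
        crop_x = (new_w - resolution) // 2 if width > height else 0
        crop_y = 0 if width > height else (new_h - resolution) // 2
        return img.crop((crop_x, crop_y, crop_x + resolution, crop_y + resolution))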
diff --git a/toolkit/ip_adapter.py b/toolkit/ip_adapter.py
index 830bba8b..2b39e9ae 100644
--- a/toolkit/ip_adapter.py
+++ b/toolkit/ip_adapter.py
@@ -365,14 +365,15 @@ class IPAdapter(torch.nn.Module):
                 input_size=preprocessor_input_size,
                 clip_input_size=self.image_encoder.config.image_size,
             )
-        if 'height' in self.clip_image_processor.size:
-            self.input_size = self.clip_image_processor.size['height']
-        elif hasattr(self.clip_image_processor, 'crop_size'):
-            self.input_size = self.clip_image_processor.crop_size['height']
-        elif 'shortest_edge' in self.clip_image_processor.size.keys():
-            self.input_size = self.clip_image_processor.size['shortest_edge']
-        else:
-            raise ValueError(f"unknown image processor size: {self.clip_image_processor.size}")
+        if self.config.image_encoder_arch != 'safe':
+            if 'height' in self.clip_image_processor.size:
+                self.input_size = self.clip_image_processor.size['height']
+            elif hasattr(self.clip_image_processor, 'crop_size'):
+                self.input_size = self.clip_image_processor.crop_size['height']
+            elif 'shortest_edge' in self.clip_image_processor.size.keys():
+                self.input_size = self.clip_image_processor.size['shortest_edge']
+            else:
+                raise ValueError(f"unknown image processor size: {self.clip_image_processor.size}")
         self.current_scale = 1.0
         self.is_active = True
         is_pixart = sd.is_pixart
diff --git a/toolkit/models/ilora.py b/toolkit/models/ilora.py
index 6bfbf0ff..03459ba8 100644
--- a/toolkit/models/ilora.py
+++ b/toolkit/models/ilora.py
@@ -21,19 +21,24 @@ class ILoRAProjModule(torch.nn.Module):
         self.num_modules = num_modules
         self.num_dim = dim
 
-        self.norm = torch.nn.LayerNorm(embeddings_dim)
         self.proj = torch.nn.Sequential(
+            torch.nn.LayerNorm(embeddings_dim),
             torch.nn.Linear(embeddings_dim, embeddings_dim * 2),
             torch.nn.GELU(),
-            torch.nn.Linear(embeddings_dim * 2, num_modules * dim),
+            torch.nn.Linear(embeddings_dim * 2, embeddings_dim * 2),
+            torch.nn.LayerNorm(embeddings_dim * 2),
+
+            torch.nn.Linear(embeddings_dim * 2, embeddings_dim * 4),
+            torch.nn.GELU(),
+            torch.nn.Linear(embeddings_dim * 4, num_modules * dim),
+            torch.nn.LayerNorm(num_modules * dim),
         )
         # Initialize the last linear layer weights near zero
-        torch.nn.init.uniform_(self.proj[2].weight, a=-0.01, b=0.01)
-        torch.nn.init.zeros_(self.proj[2].bias)
+        torch.nn.init.uniform_(self.proj[-2].weight, a=-0.01, b=0.01)
+        torch.nn.init.zeros_(self.proj[-2].bias)
 
     def forward(self, x):
-        x = self.norm(x)
         x = self.proj(x)
         x = x.reshape(-1, self.num_modules, self.num_dim)
         return x
@@ -71,6 +76,8 @@ class InstantLoRAMidModule(torch.nn.Module):
                 # reshape if needed
                 if len(x.shape) == 3:
                     scaler = scaler.unsqueeze(1)
+                if len(x.shape) == 4:
+                    scaler = scaler.unsqueeze(-1).unsqueeze(-1)
             except Exception as e:
                 print(e)
                 print(x.shape)
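The added branch above handles 4D activations: the per-module scaler coming out of the projector appears to be (batch, channels), which broadcasts against 3D sequence activations after unsqueeze(1), but a conv feature map (B, C, H, W) needs trailing singleton dims instead. A quick illustration:

    import torch

    x = torch.randn(2, 8, 16, 16)                 # conv activation (B, C, H, W)
    scaler = torch.randn(2, 8)                    # per-channel scale (B, C)
    out = x * scaler.unsqueeze(-1).unsqueeze(-1)  # (B, C, 1, 1) broadcasts over H and W
    assert out.shape == x.shape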
diff --git a/toolkit/models/size_agnostic_feature_encoder.py b/toolkit/models/size_agnostic_feature_encoder.py
index 15f7a439..a716aec5 100644
--- a/toolkit/models/size_agnostic_feature_encoder.py
+++ b/toolkit/models/size_agnostic_feature_encoder.py
@@ -20,11 +20,11 @@ class SAFEReducerBlock(nn.Module):
 
         self.reducer = nn.Sequential(
             nn.Conv2d(channels, channels, kernel_size=3, padding=1),
-            nn.BatchNorm2d(channels),
             activation(),
+            nn.BatchNorm2d(channels),
             nn.Conv2d(channels, channels, kernel_size=3, padding=1),
-            nn.BatchNorm2d(channels),
             activation(),
+            nn.BatchNorm2d(channels),
             nn.AvgPool2d(kernel_size=2, stride=2),
         )
         self.residual_shrink = nn.AvgPool2d(kernel_size=2, stride=2)
@@ -227,6 +227,7 @@ class SAFEVMConfig:
         self.reducer_channels = reducer_channels
         self.channels = channels
         self.downscale_factor = downscale_factor
+        self.image_size = 224
         self.hidden_size = num_vectors
         self.projection_dim = num_vectors
@@ -242,7 +243,9 @@ class SAFEVMReturn:
 
 class SAFEVisionModel(SizeAgnosticFeatureEncoder):
     def __init__(self, **kwargs):
         self.config = SAFEVMConfig(**kwargs)
-        super().__init__(**kwargs)
+        self.image_size = None
+        # super().__init__(**kwargs)
+        super(SAFEVisionModel, self).__init__(**kwargs)
 
     @classmethod
     def from_pretrained(cls, *args, **kwargs):
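After this reordering, the reducer block runs conv, activation, then batch norm (twice), followed by a 2x average-pool, with the input average-pooled on a parallel path. A minimal sketch of the block in isolation; the residual addition in forward is assumed from the residual_shrink member, since the original forward is not shown in this hunk:

    import torch
    import torch.nn as nn

    class ReducerBlock(nn.Module):
        def __init__(self, channels: int, activation=nn.GELU):
            super().__init__()
            self.reducer = nn.Sequential(
                nn.Conv2d(channels, channels, kernel_size=3, padding=1),
                activation(),
                nn.BatchNorm2d(channels),
                nn.Conv2d(channels, channels, kernel_size=3, padding=1),
                activation(),
                nn.BatchNorm2d(channels),
                nn.AvgPool2d(kernel_size=2, stride=2),  # halves H and W
            )
            # shrink the skip path to match the pooled output
            self.residual_shrink = nn.AvgPool2d(kernel_size=2, stride=2)

        def forward(self, x):
            return self.reducer(x) + self.residual_shrink(x)

    print(ReducerBlock(32)(torch.randn(1, 32, 64, 64)).shape)  # (1, 32, 32, 32)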
diff --git a/toolkit/optimizer.py b/toolkit/optimizer.py
index 5d2dddee..f61fea88 100644
--- a/toolkit/optimizer.py
+++ b/toolkit/optimizer.py
@@ -20,12 +20,12 @@ def get_optimizer(
         # dadaptation uses a different lr scale, roughly 0.1 to 1.0; default to 1.0
         use_lr = 1.0
         if lower_type.endswith('lion'):
             optimizer = dadaptation.DAdaptLion(params, lr=use_lr, **optimizer_params)
         elif lower_type.endswith('adam'):
-            optimizer = dadaptation.DAdaptLion(params, lr=use_lr, **optimizer_params)
+            optimizer = dadaptation.DAdaptAdam(params, eps=1e-6, lr=use_lr, **optimizer_params)
         elif lower_type == 'dadaptation':
             # backwards compatibility
-            optimizer = dadaptation.DAdaptAdam(params, lr=use_lr, **optimizer_params)
+            optimizer = dadaptation.DAdaptAdam(params, eps=1e-6, lr=use_lr, **optimizer_params)
             # warn user that dadaptation is deprecated
             print("WARNING: Dadaptation optimizer type has been changed to DadaptationAdam. Please update your config.")
     elif lower_type.startswith("prodigy"):
@@ -40,22 +40,22 @@ def get_optimizer(
         print(f"Using lr {use_lr}")
         # let net be the neural network you want to train
         # you can choose weight decay value based on your problem, 0 by default
-        optimizer = Prodigy(params, lr=use_lr, **optimizer_params)
+        optimizer = Prodigy(params, lr=use_lr, eps=1e-6, **optimizer_params)
     elif lower_type.endswith("8bit"):
         import bitsandbytes
 
         if lower_type == "adam8bit":
-            return bitsandbytes.optim.Adam8bit(params, lr=learning_rate, **optimizer_params)
+            return bitsandbytes.optim.Adam8bit(params, lr=learning_rate, eps=1e-6, **optimizer_params)
         elif lower_type == "adamw8bit":
-            return bitsandbytes.optim.AdamW8bit(params, lr=learning_rate, **optimizer_params)
+            return bitsandbytes.optim.AdamW8bit(params, lr=learning_rate, eps=1e-6, **optimizer_params)
         elif lower_type == "lion8bit":
             return bitsandbytes.optim.Lion8bit(params, lr=learning_rate, **optimizer_params)
         else:
             raise ValueError(f'Unknown optimizer type {optimizer_type}')
     elif lower_type == 'adam':
-        optimizer = torch.optim.Adam(params, lr=float(learning_rate), **optimizer_params)
+        optimizer = torch.optim.Adam(params, lr=float(learning_rate), eps=1e-6, **optimizer_params)
     elif lower_type == 'adamw':
-        optimizer = torch.optim.AdamW(params, lr=float(learning_rate), **optimizer_params)
+        optimizer = torch.optim.AdamW(params, lr=float(learning_rate), eps=1e-6, **optimizer_params)
     elif lower_type == 'lion':
         try:
             from lion_pytorch import Lion
             optimizer = Lion(params, lr=learning_rate, **optimizer_params)
         except ImportError:
             raise ImportError("Please install lion_pytorch to use Lion optimizer -> pip install lion-pytorch")
     elif lower_type == 'adagrad':
-        optimizer = torch.optim.Adagrad(params, lr=float(learning_rate), **optimizer_params)
+        optimizer = torch.optim.Adagrad(params, lr=float(learning_rate), eps=1e-6, **optimizer_params)
     elif lower_type == 'adafactor':
         # hack in stochastic rounding
         if 'relative_step' not in optimizer_params:
             optimizer_params['relative_step'] = False
         if 'scale_parameter' not in optimizer_params:
             optimizer_params['scale_parameter'] = True
         if 'warmup_init' not in optimizer_params:
             optimizer_params['warmup_init'] = False
         optimizer = Adafactor(params, lr=float(learning_rate), **optimizer_params)
         from toolkit.util.adafactor_stochastic_rounding import step_adafactor
         optimizer.step = step_adafactor.__get__(optimizer, Adafactor)
     else:
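On the eps changes: Adam-family updates divide by sqrt(v) + eps, and the common default eps=1e-8 sits below float16's smallest subnormal (~6e-8), so under fp16 training it flushes to zero and the division can blow up to inf/NaN on weights whose second moment is tiny. 1e-6 still survives in fp16, which is presumably why it is threaded through every optimizer here that takes a scalar eps (sign-based Lion variants have no eps term at all, and transformers' Adafactor takes a tuple eps, so those branches are left alone). A quick demonstration of the underflow:

    import torch

    print(torch.tensor(1e-8, dtype=torch.float16))  # tensor(0., dtype=torch.float16) -- underflows
    print(torch.tensor(1e-6, dtype=torch.float16))  # still nonzero

    v = torch.tensor(0.0, dtype=torch.float16)      # second moment of a dead weight
    grad = torch.tensor(1e-3, dtype=torch.float16)
    for eps in (1e-8, 1e-6):
        eps16 = torch.tensor(eps, dtype=torch.float16)
        print(eps, (grad / (v.sqrt() + eps16)).item())  # inf with 1e-8, finite with 1e-6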