diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index 27346c4c..9e326619 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -152,7 +152,7 @@ class SDTrainer(BaseSDTrainProcess):
         if self.train_config.learnable_snr_gos:
             # add snr_gamma
             loss = apply_learnable_snr_gos(loss, timesteps, self.snr_gos)
-        if self.train_config.snr_gamma is not None and self.train_config.snr_gamma > 0.000001 and not ignore_snr:
+        elif self.train_config.snr_gamma is not None and self.train_config.snr_gamma > 0.000001 and not ignore_snr:
            # add snr_gamma
            loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.snr_gamma, fixed=True)
        elif self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001 and not ignore_snr:
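Note on the SDTrainer change: switching the second branch to elif makes the SNR weighting modes mutually exclusive, so a run with learnable_snr_gos enabled no longer also applies the fixed snr_gamma weight to the same loss. For orientation, here is a minimal sketch of what min-SNR-style weighting typically computes, assuming a DDPM-style scheduler that exposes alphas_cumprod; the helper name is hypothetical and the toolkit's apply_snr_weight may differ in detail:

    import torch

    def min_snr_weight_sketch(loss, timesteps, alphas_cumprod, gamma=5.0):
        # loss: per-sample loss of shape (bs,); timesteps: integer tensor of shape (bs,)
        alpha_bar = alphas_cumprod[timesteps]
        snr = alpha_bar / (1.0 - alpha_bar)  # SNR(t) for a DDPM-style schedule
        # cap the weight at gamma so low-noise timesteps do not dominate training
        weight = torch.minimum(snr, torch.full_like(snr, gamma)) / snr
        return loss * weight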
diff --git a/jobs/process/TrainSliderProcess.py b/jobs/process/TrainSliderProcess.py
index 2288efe5..f412d1ee 100644
--- a/jobs/process/TrainSliderProcess.py
+++ b/jobs/process/TrainSliderProcess.py
@@ -13,7 +13,7 @@
 from toolkit.basic import value_map
 from toolkit.config_modules import SliderConfig
 from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
 from toolkit.sd_device_states_presets import get_train_sd_device_state_preset
-from toolkit.train_tools import get_torch_dtype, apply_snr_weight
+from toolkit.train_tools import get_torch_dtype, apply_snr_weight, apply_learnable_snr_gos
 import gc
 from toolkit import train_tools
 from toolkit.prompt_utils import \
@@ -35,6 +35,7 @@
 adapter_transforms = transforms.Compose([
     transforms.ToTensor(),
 ])
+
 class TrainSliderProcess(BaseSDTrainProcess):
     def __init__(self, process_id: int, job, config: OrderedDict):
         super().__init__(process_id, job, config)
@@ -273,7 +274,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
             )
             return adapter_tensors

-    def hook_train_loop(self, batch):
+    def hook_train_loop(self, batch: Union['DataLoaderBatchDTO', None]):
         # set to eval mode
         self.sd.set_device_state(self.eval_slider_device_state)
         with torch.no_grad():
@@ -309,7 +310,8 @@ class TrainSliderProcess(BaseSDTrainProcess):
                 if dbr_batch_size != dn.shape[0]:
                     amount_to_add = int(dn.shape[0] * 2 / dbr_batch_size)
                     down_kwargs['down_block_additional_residuals'] = [
-                        torch.cat([sample.clone()] * amount_to_add) for sample in down_kwargs['down_block_additional_residuals']
+                        torch.cat([sample.clone()] * amount_to_add) for sample in
+                        down_kwargs['down_block_additional_residuals']
                     ]
             return self.sd.predict_noise(
                 latents=dn,
@@ -325,6 +327,7 @@ class TrainSliderProcess(BaseSDTrainProcess):

         with torch.no_grad():
             adapter_images = None
+            # for a complete slider, the batch size is now 4 to begin with
             true_batch_size = prompt_pair.target_class.text_embeds.shape[0] * self.train_config.batch_size
             from_batch = False
@@ -370,7 +373,6 @@ class TrainSliderProcess(BaseSDTrainProcess):
                 1, self.train_config.max_denoising_steps, (1,)
             ).item()

-
             # get noise
             noise = self.sd.get_latent_noise(
                 pixel_height=height,
@@ -401,7 +403,6 @@ class TrainSliderProcess(BaseSDTrainProcess):

             noise_scheduler.set_timesteps(1000)

-
             current_timestep_index = int(timesteps_to * 1000 / self.train_config.max_denoising_steps)
             current_timestep = noise_scheduler.timesteps[current_timestep_index]
@@ -410,6 +411,33 @@ class TrainSliderProcess(BaseSDTrainProcess):
             denoised_latent_chunks = [x.detach() for x in denoised_latent_chunks]

             # flush()  # 4.2GB to 3GB on 512x512
+            mask_multiplier = torch.ones((denoised_latents.shape[0], 1, 1, 1),
+                                         device=self.device_torch, dtype=dtype)
+            has_mask = False
+            if batch and batch.mask_tensor is not None:
+                with self.timer('get_mask_multiplier'):
+                    # upsampling not supported for bfloat16
+                    mask_multiplier = batch.mask_tensor.to(self.device_torch, dtype=torch.float16).detach()
+                    # scale down to the size of the latents; mask_multiplier shape (bs, 1, width, height), noisy_latents shape (bs, channels, width, height)
+                    mask_multiplier = torch.nn.functional.interpolate(
+                        mask_multiplier, size=(noisy_latents.shape[2], noisy_latents.shape[3])
+                    )
+                    # expand to match latents
+                    mask_multiplier = mask_multiplier.expand(-1, noisy_latents.shape[1], -1, -1)
+                    mask_multiplier = mask_multiplier.to(self.device_torch, dtype=dtype).detach()
+                    has_mask = True
+
+            if has_mask:
+                unmasked_target = get_noise_pred(
+                    prompt_pair.positive_target,  # negative prompt
+                    prompt_pair.target_class,  # positive prompt
+                    1,
+                    current_timestep,
+                    denoised_latents
+                )
+                unmasked_target = unmasked_target.detach()
+                unmasked_target.requires_grad = False
+            else:
+                unmasked_target = None

             # 4.20 GB RAM for 512x512
             positive_latents = get_noise_pred(
@@ -504,19 +532,30 @@ class TrainSliderProcess(BaseSDTrainProcess):
                     anchor.to("cpu")

         with torch.no_grad():
-            if self.slider_config.high_ram:
+            if self.slider_config.low_ram:
+                prompt_pair_chunks = split_prompt_pairs(prompt_pair.detach(), self.prompt_chunk_size)
+                denoised_latent_chunks = denoised_latent_chunks  # just to have it in one place
+                positive_latents_chunks = torch.chunk(positive_latents.detach(), self.prompt_chunk_size, dim=0)
+                neutral_latents_chunks = torch.chunk(neutral_latents.detach(), self.prompt_chunk_size, dim=0)
+                unconditional_latents_chunks = torch.chunk(
+                    unconditional_latents.detach(),
+                    self.prompt_chunk_size,
+                    dim=0
+                )
+                mask_multiplier_chunks = torch.chunk(mask_multiplier, self.prompt_chunk_size, dim=0)
+                if unmasked_target is not None:
+                    unmasked_target_chunks = torch.chunk(unmasked_target, self.prompt_chunk_size, dim=0)
+                else:
+                    unmasked_target_chunks = [None for _ in range(self.prompt_chunk_size)]
+            else:
                 # run through in one instance
                 prompt_pair_chunks = [prompt_pair.detach()]
                 denoised_latent_chunks = [torch.cat(denoised_latent_chunks, dim=0).detach()]
                 positive_latents_chunks = [positive_latents.detach()]
                 neutral_latents_chunks = [neutral_latents.detach()]
                 unconditional_latents_chunks = [unconditional_latents.detach()]
-            else:
-                prompt_pair_chunks = split_prompt_pairs(prompt_pair.detach(), self.prompt_chunk_size)
-                denoised_latent_chunks = denoised_latent_chunks  # just to have it in one place
-                positive_latents_chunks = torch.chunk(positive_latents.detach(), self.prompt_chunk_size, dim=0)
-                neutral_latents_chunks = torch.chunk(neutral_latents.detach(), self.prompt_chunk_size, dim=0)
-                unconditional_latents_chunks = torch.chunk(unconditional_latents.detach(), self.prompt_chunk_size, dim=0)
+                mask_multiplier_chunks = [mask_multiplier]
+                unmasked_target_chunks = [unmasked_target]

         # flush()
         assert len(prompt_pair_chunks) == len(denoised_latent_chunks)
@@ -528,13 +567,17 @@ class TrainSliderProcess(BaseSDTrainProcess):
         for prompt_pair_chunk, \
                 denoised_latent_chunk, \
                 positive_latents_chunk, \
                 neutral_latents_chunk, \
-                unconditional_latents_chunk \
+                unconditional_latents_chunk, \
+                mask_multiplier_chunk, \
+                unmasked_target_chunk \
                 in zip(
             prompt_pair_chunks,
             denoised_latent_chunks,
             positive_latents_chunks,
             neutral_latents_chunks,
             unconditional_latents_chunks,
+            mask_multiplier_chunks,
+            unmasked_target_chunks
         ):
             self.network.multiplier = prompt_pair_chunk.multiplier_list
             target_latents = get_noise_pred(
@@ -568,17 +611,43 @@ class TrainSliderProcess(BaseSDTrainProcess):

             # 16.15 GB RAM for 512x512 -> 4.20GB RAM for 512x512 with new grad_checkpointing
             loss = torch.nn.functional.mse_loss(target_latents.float(), offset_neutral.float(), reduction="none")
+
+            # apply the inverted mask to preserve the non-masked region
+            if has_mask and unmasked_target_chunk is not None:
+                loss = loss * mask_multiplier_chunk
+                # match the unmasked region to unmasked_target_chunk
+                mask_target_loss = torch.nn.functional.mse_loss(
+                    target_latents.float(),
+                    unmasked_target_chunk.float(),
+                    reduction="none"
+                )
+                mask_target_loss = mask_target_loss * (1.0 - mask_multiplier_chunk)
+                loss += mask_target_loss
+
             loss = loss.mean([1, 2, 3])
+            if self.train_config.learnable_snr_gos:
+                if from_batch:
+                    # match batch size
+                    loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler,
+                                            self.train_config.min_snr_gamma)
+                else:
+                    # match batch size
+                    timesteps_index_list = [current_timestep_index for _ in range(target_latents.shape[0])]
+                    # add snr_gamma
+                    loss = apply_learnable_snr_gos(loss, timesteps_index_list, self.snr_gos)
             if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
                 if from_batch:
                     # match batch size
-                    loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
+                    loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler,
+                                            self.train_config.min_snr_gamma)
                 else:
                     # match batch size
                     timesteps_index_list = [current_timestep_index for _ in range(target_latents.shape[0])]
                     # add min_snr_gamma
-                    loss = apply_snr_weight(loss, timesteps_index_list, noise_scheduler, self.train_config.min_snr_gamma)
+                    loss = apply_snr_weight(loss, timesteps_index_list, noise_scheduler,
+                                            self.train_config.min_snr_gamma)
+
             loss = loss.mean() * prompt_pair_chunk.weight
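The loss hunk above blends two objectives: inside the mask the prediction is trained toward the usual slider target, while outside the mask it is pulled toward unmasked_target, a prediction made without the network active, so untouched regions are preserved. A standalone sketch of that combination, with hypothetical tensor names and the same broadcasting assumption as the hunk (the mask already expanded to the latent shape):

    import torch
    import torch.nn.functional as F

    def masked_slider_loss_sketch(pred, slider_target, unmasked_target, mask):
        # mask is ~1 where the slider should act, ~0 where the image must be preserved
        slider_term = F.mse_loss(pred.float(), slider_target.float(), reduction="none") * mask
        # outside the mask, match the prediction made with the network disabled
        preserve_term = F.mse_loss(pred.float(), unmasked_target.float(), reduction="none") * (1.0 - mask)
        # reduce to one loss value per sample, as loss.mean([1, 2, 3]) does above
        return (slider_term + preserve_term).mean(dim=[1, 2, 3])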
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index 9fb51b69..4406a85b 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -283,7 +283,7 @@ class SliderConfig:
         self.batch_full_slide: bool = kwargs.get('batch_full_slide', True)
         self.use_adapter: bool = kwargs.get('use_adapter', None)  # depth
         self.adapter_img_dir = kwargs.get('adapter_img_dir', None)
-        self.high_ram = kwargs.get('high_ram', False)
+        self.low_ram = kwargs.get('low_ram', False)

         # expand targets if shuffling
         from toolkit.prompt_utils import get_slider_target_permutations
@@ -334,6 +334,7 @@ class DatasetConfig:
         self.alpha_mask: bool = kwargs.get('alpha_mask', False)  # if true, will use alpha channel as mask
         self.mask_path: str = kwargs.get('mask_path', None)  # focus mask (black and white. White has higher loss than black)
+        self.invert_mask: bool = kwargs.get('invert_mask', False)  # invert mask
         self.mask_min_value: float = kwargs.get('mask_min_value', 0.01)  # min value for mask. 0 - 1
         self.poi: Union[str, None] = kwargs.get('poi', None)  # if one is set and in json data, will be used as auto crop scale point of interest
diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py
index 589a22cd..0326c931 100644
--- a/toolkit/dataloader_mixins.py
+++ b/toolkit/dataloader_mixins.py
@@ -18,7 +18,7 @@
 from toolkit.buckets import get_bucket_for_image_size
 from toolkit.metadata import get_meta_for_safetensors
 from toolkit.prompt_utils import inject_trigger_into_prompt
 from torchvision import transforms
-from PIL import Image, ImageFilter
+from PIL import Image, ImageFilter, ImageOps
 from PIL.ImageOps import exif_transpose
 import albumentations as A
@@ -612,6 +612,8 @@ class MaskFileItemDTOMixin:
             img = Image.fromarray(np_img)

         img = img.convert('RGB')
+        if self.dataset_config.invert_mask:
+            img = ImageOps.invert(img)
         w, h = img.size
         if w > h and self.scale_to_width < self.scale_to_height:
             # throw error, they should match
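Taken together, the two changes above let a dataset declare that its masks are stored inverted: DatasetConfig gains an invert_mask flag, and the mask loader applies ImageOps.invert (255 minus each channel value) before the mask is scaled and floored. A hypothetical configuration sketch, assuming DatasetConfig accepts these as keyword arguments as the kwargs.get calls above suggest; the path is illustrative:

    from toolkit.config_modules import DatasetConfig

    dataset_config = DatasetConfig(
        mask_path='/path/to/masks',  # illustrative path; white = higher loss, black = lower
        invert_mask=True,            # new flag: flip black/white before the mask is used
        mask_min_value=0.01,         # floor applied to the (possibly inverted) mask
    )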
diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index 0a732f4e..dbacfa79 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -535,7 +535,7 @@ class StableDiffusion:
             text_embeddings: Union[PromptEmbeds, None] = None,
             timestep: Union[int, torch.Tensor] = 1,
             guidance_scale=7.5,
-            guidance_rescale=0,  # 0.7 sdxl
+            guidance_rescale=0,
             add_time_ids=None,
             conditional_embeddings: Union[PromptEmbeds, None] = None,
             unconditional_embeddings: Union[PromptEmbeds, None] = None,
@@ -674,7 +674,7 @@
                     add_time_ids=add_time_ids,
                     **kwargs,
                 )
-                latents = self.noise_scheduler.step(noise_pred, timestep, latents).prev_sample
+                latents = self.noise_scheduler.step(noise_pred, timestep, latents, return_dict=False)[0]

         # return latents_steps
         return latents
diff --git a/toolkit/train_tools.py b/toolkit/train_tools.py
index edf1ee7d..004c6503 100644
--- a/toolkit/train_tools.py
+++ b/toolkit/train_tools.py
@@ -691,12 +691,12 @@ class LearnableSNRGamma:
     def __init__(self, noise_scheduler: Union['DDPMScheduler'], device='cuda'):
         self.device = device
         self.noise_scheduler: Union['DDPMScheduler'] = noise_scheduler
-        self.offset = torch.nn.Parameter(torch.tensor(1.0, dtype=torch.float32, device=device))
-        self.scale = torch.nn.Parameter(torch.tensor(0.001, dtype=torch.float32, device=device))
-        self.gamma = torch.nn.Parameter(torch.tensor(1.0, dtype=torch.float32, device=device))
-        self.optimizer = torch.optim.AdamW([self.offset, self.gamma, self.scale], lr=0.1)
+        self.offset = torch.nn.Parameter(torch.tensor(0.777, dtype=torch.float32, device=device))
+        self.scale = torch.nn.Parameter(torch.tensor(4.14, dtype=torch.float32, device=device))
+        self.gamma = torch.nn.Parameter(torch.tensor(2.03, dtype=torch.float32, device=device))
+        self.optimizer = torch.optim.AdamW([self.offset, self.gamma, self.scale], lr=0.01)
         self.buffer = []
-        self.max_buffer_size = 100
+        self.max_buffer_size = 20

     def forward(self, loss, timesteps):
         # do our train loop for lsnr here and return our values detached
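On the stable_diffusion_model.py change: for diffusers schedulers, step(..., return_dict=False) returns a plain tuple whose first element is the previous sample, so the new line is behaviorally identical while skipping construction of the SchedulerOutput dataclass on every denoising step. A small self-contained check, assuming the diffusers package; DDIMScheduler is used here because its default step is deterministic:

    import torch
    from diffusers import DDIMScheduler

    scheduler = DDIMScheduler()
    scheduler.set_timesteps(50)
    latents = torch.randn(1, 4, 64, 64)
    noise_pred = torch.randn_like(latents)
    t = scheduler.timesteps[0]

    # the attribute form and the tuple form yield the same tensor
    prev_a = scheduler.step(noise_pred, t, latents).prev_sample
    prev_b = scheduler.step(noise_pred, t, latents, return_dict=False)[0]
    assert torch.allclose(prev_a, prev_b)

The train_tools.py hunk only retunes the LearnableSNRGamma starting point (offset, scale, gamma), lowers its learning rate, and shrinks its sample buffer; the fitting logic in forward is untouched by this hunk.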