From 792a5e37e21830d15cd342607c12619805875d5a Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Tue, 28 Nov 2023 07:34:43 -0700
Subject: [PATCH] Numerous fixes for time sampling. Still not perfect

---
 extensions_built_in/sd_trainer/SDTrainer.py |   6 +-
 jobs/process/BaseSDTrainProcess.py          |  34 +++---
 testing/test_bucket_dataloader.py           |  34 +++++-
 toolkit/config_modules.py                   |   1 +
 toolkit/dataloader_mixins.py                | 120 +++++++++-----------
 toolkit/stable_diffusion_model.py           |  43 ++++++-
 toolkit/train_tools.py                      |  13 ++-
 7 files changed, 160 insertions(+), 91 deletions(-)

diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index f33f6b32..69e51fc6 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -34,6 +34,7 @@ class SDTrainer(BaseSDTrainProcess):
         super().__init__(process_id, job, config, **kwargs)
         self.assistant_adapter: Union['T2IAdapter', None]
         self.do_prior_prediction = False
+        self.do_long_prompts = False
         if self.train_config.inverted_mask_prior:
             self.do_prior_prediction = True
 
@@ -126,6 +127,7 @@ class SDTrainer(BaseSDTrainProcess):
             # we also denoise, as the unaugmented tensor is not a noisy differential
             with torch.no_grad():
                 unaugmented_latents = self.sd.encode_images(batch.unaugmented_tensor)
+                unaugmented_latents = unaugmented_latents * self.train_config.latent_multiplier
                 target = unaugmented_latents.detach()
 
         # Get the target for loss depending on the prediction type
@@ -492,7 +494,7 @@ class SDTrainer(BaseSDTrainProcess):
                 conditional_embeds = self.sd.encode_prompt(
                     conditioned_prompts, prompt_2,
                     dropout_prob=self.train_config.prompt_dropout_prob,
-                    long_prompts=True).to(
+                    long_prompts=self.do_long_prompts).to(
                     self.device_torch,
                     dtype=dtype)
         else:
@@ -506,7 +508,7 @@ class SDTrainer(BaseSDTrainProcess):
             conditional_embeds = self.sd.encode_prompt(
                 conditioned_prompts, prompt_2,
                 dropout_prob=self.train_config.prompt_dropout_prob,
-                long_prompts=True).to(
+                long_prompts=self.do_long_prompts).to(
                 self.device_torch,
                 dtype=dtype)
 
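A recurring theme in this patch is that the new `latent_multiplier` must be applied to every latent tensor that participates in the loss (the model inputs, the unaugmented prior targets, and the unconditional latents), otherwise the loss compares tensors living at different scales. A minimal sketch of that invariant, assuming a hypothetical `vae_encode` callable; the helper name is illustrative, not the toolkit's API:

    import torch

    def encode_scaled(vae_encode, images: torch.Tensor, latent_multiplier: float = 1.0) -> torch.Tensor:
        """Encode images and apply one consistent multiplier to the latents.

        Every latent tensor compared in the loss must be scaled by the same
        factor, which is what the patch enforces by threading
        train_config.latent_multiplier through each encode site.
        """
        with torch.no_grad():
            latents = vae_encode(images)  # raw VAE latents
        return latents * latent_multiplier

With a helper like this, the three call sites touched in this patch would all go through the same scaling path instead of repeating the multiplication.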
diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index 09ec99e2..4a98f17b 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -666,7 +666,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 unconditional_imgs = batch.unconditional_tensor
                 unconditional_imgs = unconditional_imgs.to(self.device_torch, dtype=dtype)
                 unconditional_latents = self.sd.encode_images(unconditional_imgs)
-                batch.unconditional_latents = unconditional_latents
+                batch.unconditional_latents = unconditional_latents * self.train_config.latent_multiplier
 
         unaugmented_latents = None
         if self.train_config.loss_target == 'differential_noise':
@@ -715,36 +715,41 @@ class BaseSDTrainProcess(BaseTrainProcess):
             orig_timesteps = torch.rand((batch_size,), device=latents.device)
 
             if self.train_config.content_or_style == 'content':
-                timesteps = orig_timesteps ** 3 * self.sd.noise_scheduler.config['num_train_timesteps']
+                timestep_indices = orig_timesteps ** 3 * self.sd.noise_scheduler.config['num_train_timesteps']
             elif self.train_config.content_or_style == 'style':
-                timesteps = (1 - orig_timesteps ** 3) * self.sd.noise_scheduler.config['num_train_timesteps']
+                timestep_indices = (1 - orig_timesteps ** 3) * self.sd.noise_scheduler.config['num_train_timesteps']
 
-            timesteps = value_map(
-                timesteps,
+            timestep_indices = value_map(
+                timestep_indices,
                 0,
                 self.sd.noise_scheduler.config['num_train_timesteps'] - 1,
                 min_noise_steps,
-                max_noise_steps
+                max_noise_steps - 1
             )
-            timesteps = timesteps.long().clamp(
+            timestep_indices = timestep_indices.long().clamp(
                 min_noise_steps + 1,
                 max_noise_steps - 1
            )
        elif self.train_config.content_or_style == 'balanced':
            if min_noise_steps == max_noise_steps:
-                timesteps = torch.ones((batch_size,), device=self.device_torch) * min_noise_steps
+                timestep_indices = torch.ones((batch_size,), device=self.device_torch) * min_noise_steps
            else:
+                # todo: some schedulers use indices, others use timesteps. Not sure what to do here
-                timesteps = torch.randint(
-                    min_noise_steps,
-                    max_noise_steps,
+                timestep_indices = torch.randint(
+                    min_noise_steps + 1,
+                    max_noise_steps - 1,
                     (batch_size,),
                     device=self.device_torch
                 )
-            timesteps = timesteps.long()
         else:
             raise ValueError(f"Unknown content_or_style {self.train_config.content_or_style}")
 
+        timestep_indices = timestep_indices.long()
+
+        # convert the timestep indices to actual scheduler timesteps
+        timesteps = [self.sd.noise_scheduler.timesteps[x.item()] for x in timestep_indices]
+        timesteps = torch.stack(timesteps, dim=0)
+
         # get noise
         noise = self.sd.get_latent_noise(
             height=latents.shape[2],
@@ -765,9 +770,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
 
         noise = noise * noise_multiplier
 
-        img_multiplier = self.train_config.img_multiplier
+        latents = latents * self.train_config.latent_multiplier
 
-        latents = latents * img_multiplier
+        if batch.unconditional_latents is not None:
+            batch.unconditional_latents = batch.unconditional_latents * self.train_config.latent_multiplier
 
         noisy_latents = self.sd.noise_scheduler.add_noise(latents, noise, timesteps)
 
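The core change above is the distinction between timestep *indices* and scheduler *timesteps*: cubing a uniform sample biases indices toward the start of the timestep array (the `content` branch), `1 - u ** 3` biases them toward the end (the `style` branch), and the sampled index is then looked up in `noise_scheduler.timesteps` instead of being used as a timestep value directly. A self-contained sketch of that two-stage sampling, assuming a 1000-step schedule stored from most to least noisy, as diffusers schedulers do (the variable names are illustrative):

    import torch

    num_train_timesteps = 1000
    # diffusers-style schedule: timesteps ordered from most to least noisy
    scheduler_timesteps = torch.arange(num_train_timesteps - 1, -1, -1)

    batch_size = 4
    u = torch.rand(batch_size)

    # cubing biases samples toward 0, i.e. toward the front of the timestep array
    indices = (u ** 3 * num_train_timesteps).long().clamp(0, num_train_timesteps - 1)

    # look the index up rather than treating it as a timestep value
    timesteps = scheduler_timesteps[indices]
    print(indices.tolist(), timesteps.tolist())  # e.g. [7, 311, 0, 54] -> [992, 688, 999, 945]

Using the raw index as a timestep only happens to work for schedulers whose timesteps are exactly `[0, num_train_timesteps)`; the lookup makes the code correct for schedules where the two differ.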
diff --git a/testing/test_bucket_dataloader.py b/testing/test_bucket_dataloader.py
index 1b3bae07..e8208107 100644
--- a/testing/test_bucket_dataloader.py
+++ b/testing/test_bucket_dataloader.py
@@ -30,9 +30,12 @@ parser.add_argument('--epochs', type=int, default=1)
 
 args = parser.parse_args()
 
 dataset_folder = args.dataset_folder
-resolution = 512
+resolution = 1024
 bucket_tolerance = 64
-batch_size = 4
+batch_size = 1
+
+
+##
 
 dataset_config = DatasetConfig(
     dataset_path=dataset_folder,
     resolution=resolution,
     default_caption='default',
     buckets=True,
     bucket_tolerance=bucket_tolerance,
-    augments=['ColorJitter'],
-    poi='person'
+    poi='person',
+    augmentations=[
+        {
+            'method': 'RandomBrightnessContrast',
+            'brightness_limit': (-0.3, 0.3),
+            'contrast_limit': (-0.3, 0.3),
+            'brightness_by_max': False,
+            'p': 1.0
+        },
+        {
+            'method': 'HueSaturationValue',
+            'hue_shift_limit': (-0, 0),
+            'sat_shift_limit': (-40, 40),
+            'val_shift_limit': (-40, 40),
+            'p': 1.0
+        },
+        # {
+        #     'method': 'RGBShift',
+        #     'r_shift_limit': (-20, 20),
+        #     'g_shift_limit': (-20, 20),
+        #     'b_shift_limit': (-20, 20),
+        #     'p': 1.0
+        # },
+    ]
 )
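The new `augmentations` list replaces the old `augments` name shorthand with per-method parameter dicts that mirror albumentations' constructor arguments (the comment added in dataloader_mixins.py below points at https://demo.albumentations.ai/ for previewing them). A plausible sketch of how such dicts can be turned into a pipeline, assuming the albumentations package is installed; the `build_augmentations` helper is illustrative, not the toolkit's actual loader:

    import albumentations as A

    def build_augmentations(aug_configs: list) -> A.Compose:
        """Instantiate albumentations transforms from {'method': name, **params} dicts."""
        transforms = []
        for cfg in aug_configs:
            cfg = dict(cfg)  # don't mutate the caller's config
            method = getattr(A, cfg.pop('method'))
            transforms.append(method(**cfg))
        return A.Compose(transforms)

    pipeline = build_augmentations([
        {'method': 'RandomBrightnessContrast',
         'brightness_limit': (-0.3, 0.3), 'contrast_limit': (-0.3, 0.3),
         'brightness_by_max': False, 'p': 1.0},
    ])
    # augmented = pipeline(image=np_image)['image']

Looking the class up by name with `getattr` keeps the config fully declarative, which is what lets the test script above describe transforms as plain dicts.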
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index b52eddfe..b7aedd35 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -193,6 +193,7 @@ class TrainConfig:
         self.adapter_assist_name_or_path: Optional[str] = kwargs.get('adapter_assist_name_or_path', None)
         self.noise_multiplier = kwargs.get('noise_multiplier', 1.0)
         self.img_multiplier = kwargs.get('img_multiplier', 1.0)
+        self.latent_multiplier = kwargs.get('latent_multiplier', 1.0)
         self.negative_prompt = kwargs.get('negative_prompt', None)
         # multiplier applied to loss on regularization images
         self.reg_weight = kwargs.get('reg_weight', 1.0)
diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py
index cca8ced1..4842556a 100644
--- a/toolkit/dataloader_mixins.py
+++ b/toolkit/dataloader_mixins.py
@@ -30,7 +30,7 @@ if TYPE_CHECKING:
 
 # def get_associated_caption_from_img_path(img_path):
 
-
+# https://demo.albumentations.ai/
 class Augments:
     def __init__(self, **kwargs):
         self.method_name = kwargs.get('method', None)
@@ -167,10 +167,11 @@ class BucketsMixin:
         width = int(file_item.width * file_item.dataset_config.scale)
         height = int(file_item.height * file_item.dataset_config.scale)
 
+        did_process_poi = False
         if file_item.has_point_of_interest:
-            # let the poi module handle the bucketing
-            file_item.setup_poi_bucket()
-        else:
+            # Attempt to process the poi if we can. It won't process if the image is smaller than the resolution
+            did_process_poi = file_item.setup_poi_bucket()
+        if not did_process_poi:
             bucket_resolution = get_bucket_for_image_size(
                 width, height,
                 resolution=resolution,
@@ -323,7 +324,7 @@ class CaptionProcessingDTOMixin:
 
         if self.dataset_config.random_triggers and len(self.dataset_config.random_triggers) > 0:
             # add random triggers
-            caption = random.choice(self.dataset_config.random_triggers) + ', ' + caption
+            caption = caption + ', ' + random.choice(self.dataset_config.random_triggers)
 
         if self.dataset_config.shuffle_tokens:
             # shuffle again
@@ -803,79 +804,68 @@ class PoiFileItemDTOMixin:
             self.poi_y = self.height - self.poi_y - self.poi_height
 
     def setup_poi_bucket(self: 'FileItemDTO'):
-        # we are using poi, so we need to calculate the bucket based on the poi
-
-        # TODO this will allow poi to be smaller than resolution. Could affect training image size
-        poi_resolution = min(
-            self.dataset_config.resolution,
-            get_resolution(
-                self.poi_width * self.dataset_config.scale,
-                self.poi_height * self.dataset_config.scale
-            )
-        )
-
-        resolution = min(self.dataset_config.resolution, poi_resolution)
-
-        bucket_tolerance = self.dataset_config.bucket_tolerance
         initial_width = int(self.width * self.dataset_config.scale)
         initial_height = int(self.height * self.dataset_config.scale)
+        # we are using poi, so we need to calculate the bucket based on the poi
+
+        # if img resolution is less than dataset resolution, just return and let the normal bucketing happen
+        img_resolution = get_resolution(initial_width, initial_height)
+        if img_resolution <= self.dataset_config.resolution:
+            return False  # will trigger normal bucketing
+
+        bucket_tolerance = self.dataset_config.bucket_tolerance
         poi_x = int(self.poi_x * self.dataset_config.scale)
         poi_y = int(self.poi_y * self.dataset_config.scale)
         poi_width = int(self.poi_width * self.dataset_config.scale)
         poi_height = int(self.poi_height * self.dataset_config.scale)
 
-        # expand poi to fit resolution
-        if poi_width < resolution:
-            width_difference = resolution - poi_width
-            poi_x = poi_x - int(width_difference / 2)
-            poi_width = resolution
-            # make sure we don't go out of bounds
-            if poi_x < 0:
+        # loop to keep expanding until we are at the proper resolution. This is not ideal; we can probably handle it better
+        num_loops = 0
+        while True:
+            # crop left
+            if poi_x > 0:
+                poi_x = random.randint(0, poi_x)
+            else:
                 poi_x = 0
-            # if total width too much, crop
-            if poi_x + poi_width > initial_width:
-                poi_width = initial_width - poi_x
 
-        if poi_height < resolution:
-            height_difference = resolution - poi_height
-            poi_y = poi_y - int(height_difference / 2)
-            poi_height = resolution
-            # make sure we don't go out of bounds
-            if poi_y < 0:
+            # crop right
+            cr_min = poi_x + poi_width
+            if cr_min < initial_width:
+                crop_right = random.randint(poi_x + poi_width, initial_width)
+            else:
+                crop_right = initial_width
+
+            poi_width = crop_right - poi_x
+
+            if poi_y > 0:
+                poi_y = random.randint(0, poi_y)
+            else:
                 poi_y = 0
-            # if total height too much, crop
-            if poi_y + poi_height > initial_height:
-                poi_height = initial_height - poi_y
 
-        # crop left
-        if poi_x > 0:
-            crop_left = random.randint(0, poi_x)
-        else:
-            crop_left = 0
+            if poi_y + poi_height < initial_height:
+                crop_bottom = random.randint(poi_y + poi_height, initial_height)
+            else:
+                crop_bottom = initial_height
 
-        # crop right
-        cr_min = poi_x + poi_width
-        if cr_min < initial_width:
-            crop_right = random.randint(poi_x + poi_width, initial_width)
-        else:
-            crop_right = initial_width
+            poi_height = crop_bottom - poi_y
 
-        if poi_y > 0:
-            crop_top = random.randint(0, poi_y)
-        else:
-            crop_top = 0
-
-        if poi_y + poi_height < initial_height:
-            crop_bottom = random.randint(poi_y + poi_height, initial_height)
-        else:
-            crop_bottom = initial_height
-
-        new_width = crop_right - crop_left
-        new_height = crop_bottom - crop_top
+            # now we have our random crop, but it may be smaller than resolution. Check and expand if needed
+            current_resolution = get_resolution(poi_width, poi_height)
+            if current_resolution >= self.dataset_config.resolution:
+                # We can break now
+                break
+            else:
+                num_loops += 1
+                if num_loops > 100:
+                    print(f"Warning: poi bucketing looped too many times. This should not happen. Please report this issue.")
+                    return False
+
+        new_width = poi_width
+        new_height = poi_height
 
         bucket_resolution = get_bucket_for_image_size(
             new_width, new_height,
-            resolution=resolution,
+            resolution=self.dataset_config.resolution,
             divisibility=bucket_tolerance
         )
@@ -888,8 +878,10 @@ class PoiFileItemDTOMixin:
         self.scale_to_height = int(initial_height * max_scale_factor)
         self.crop_width = bucket_resolution['width']
         self.crop_height = bucket_resolution['height']
-        self.crop_x = int(crop_left * max_scale_factor)
-        self.crop_y = int(crop_top * max_scale_factor)
+        self.crop_x = int(poi_x * max_scale_factor)
+        self.crop_y = int(poi_y * max_scale_factor)
+
+        return True
 
 
 class ArgBreakMixin:
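The rewritten `setup_poi_bucket` replaces deterministic centering with a rejection-sampling loop: pick a random crop that still contains the point-of-interest box, and retry (up to 100 times) until the crop reaches the dataset resolution. A standalone sketch of the idea, assuming a `resolution_of(w, h)` helper comparable to the toolkit's `get_resolution`; note this simplified variant re-samples from the original box on each try, where the patch instead keeps expanding the previous crop:

    import math
    import random

    def resolution_of(w: int, h: int) -> int:
        # stand-in for get_resolution: the side of a square with the same area
        return int(math.sqrt(w * h))

    def random_crop_around_poi(img_w, img_h, x, y, w, h, target_res, max_tries=100):
        """Random crop containing the box (x, y, w, h), at least target_res in size."""
        for _ in range(max_tries):
            left = random.randint(0, x) if x > 0 else 0
            right = random.randint(x + w, img_w) if x + w < img_w else img_w
            top = random.randint(0, y) if y > 0 else 0
            bottom = random.randint(y + h, img_h) if y + h < img_h else img_h
            if resolution_of(right - left, bottom - top) >= target_res:
                return left, top, right - left, bottom - top
        return None  # caller falls back to normal bucketing

    print(random_crop_around_poi(2048, 1536, 600, 400, 512, 512, 1024))

Randomizing the crop position, rather than always centering the point of interest, gives the trainer varied framings of the subject across epochs while guaranteeing it stays in frame.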
diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index 8845e30a..9ee4d995 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -193,6 +193,9 @@ class StableDiffusion:
                 device=self.device_torch,
                 torch_dtype=self.torch_dtype,
             )
+
+            if 'vae' in load_args and load_args['vae'] is not None:
+                pipe.vae = load_args['vae']
             flush()
 
             text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
@@ -679,6 +682,18 @@ class StableDiffusion:
         text_embeddings = text_embeddings.to(self.device_torch)
         timestep = timestep.to(self.device_torch)
 
+        def scale_model_input(model_input, timestep_tensor):
+            mi_chunks = torch.chunk(model_input, model_input.shape[0], dim=0)
+            out_chunks = []
+            for idx in range(model_input.shape[0]):
+                # if the scheduler tracks a step index, reset it so it is recalculated per sample
+                if hasattr(self.noise_scheduler, '_step_index'):
+                    self.noise_scheduler._step_index = None
+                out_chunks.append(
+                    self.noise_scheduler.scale_model_input(mi_chunks[idx], timestep_tensor[idx])
+                )
+            return torch.cat(out_chunks, dim=0)
+
         if self.is_xl:
             with torch.no_grad():
                 # 16, 6 for bs of 4
@@ -691,10 +706,11 @@ class StableDiffusion:
                 if do_classifier_free_guidance:
                     latent_model_input = torch.cat([latents] * 2)
+                    timestep = torch.cat([timestep] * 2)
                 else:
                     latent_model_input = latents
 
-                latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, timestep)
+                latent_model_input = scale_model_input(latent_model_input, timestep)
 
                 added_cond_kwargs = {
                     # todo can we zero here the second text encoder? or match a blank string?
@@ -784,10 +800,11 @@ class StableDiffusion:
             if do_classifier_free_guidance:
                 # if we are doing classifier free guidance, need to double up
                 latent_model_input = torch.cat([latents] * 2)
+                timestep = torch.cat([timestep] * 2)
             else:
                 latent_model_input = latents
 
-            latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, timestep)
+            latent_model_input = scale_model_input(latent_model_input, timestep)
 
             # check if we need to concat timesteps
             if isinstance(timestep, torch.Tensor) and len(timestep.shape) > 1:
@@ -823,6 +840,23 @@ class StableDiffusion:
 
         return noise_pred
 
+    def step_scheduler(self, model_input, latent_input, timestep_tensor):
+        mi_chunks = torch.chunk(model_input, model_input.shape[0], dim=0)
+        latent_chunks = torch.chunk(latent_input, latent_input.shape[0], dim=0)
+        timestep_chunks = torch.chunk(timestep_tensor, timestep_tensor.shape[0], dim=0)
+        out_chunks = []
+        for idx in range(model_input.shape[0]):
+            # Reset the step index so it is recalculated for each sample
+            if hasattr(self.noise_scheduler, '_step_index'):
+                self.noise_scheduler._step_index = None
+            if hasattr(self.noise_scheduler, 'is_scale_input_called'):
+                self.noise_scheduler.is_scale_input_called = True
+            out_chunks.append(
+                self.noise_scheduler.step(mi_chunks[idx], timestep_chunks[idx], latent_chunks[idx], return_dict=False)[0]
+            )
+        return torch.cat(out_chunks, dim=0)
+
     # ref: https://github.com/huggingface/diffusers/blob/0bab447670f47c28df60fbd2f6a0f833f75a16f5/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L746
     def diffuse_some_steps(
         self,
@@ -839,6 +873,7 @@ class StableDiffusion:
         timesteps_to_run = self.noise_scheduler.timesteps[start_timesteps:total_timesteps]
 
         for timestep in tqdm(timesteps_to_run, leave=False):
+            timestep = timestep.unsqueeze_(0)
             noise_pred = self.predict_noise(
                 latents,
                 text_embeddings,
                 timestep,
                 guidance_scale=guidance_scale,
                 add_time_ids=add_time_ids,
                 **kwargs,
             )
-            latents = self.noise_scheduler.step(noise_pred, timestep, latents, return_dict=False)[0]
+            # some schedulers need to run each sample separately (euler, for example)
+            latents = self.step_scheduler(noise_pred, latents, timestep)
 
             # if not last step, and bleeding, bleed in some latents
             if bleed_latents is not None and timestep != self.noise_scheduler.timesteps[-1]:
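Both `scale_model_input` and `step_scheduler` above work around the same issue: stateful schedulers such as Euler cache a `_step_index` (in recent diffusers versions) and assume every call advances one batch-wide timestep, so a batch in which each sample sits at a different timestep has to be processed one sample at a time, with the cached index reset before each call. A minimal sketch of the pattern against diffusers' `EulerDiscreteScheduler`; the private-attribute handling is an assumption that tracks the diffusers version in use:

    import torch
    from diffusers import EulerDiscreteScheduler

    scheduler = EulerDiscreteScheduler()
    scheduler.set_timesteps(30)

    def step_per_sample(noise_pred: torch.Tensor, latents: torch.Tensor, timesteps: torch.Tensor):
        """Step each batch element at its own timestep despite the scheduler's cached state."""
        outs = []
        for i in range(latents.shape[0]):
            if hasattr(scheduler, "_step_index"):
                scheduler._step_index = None  # force re-lookup from the timestep value
            if hasattr(scheduler, "is_scale_input_called"):
                scheduler.is_scale_input_called = True  # silence the scale_model_input warning, as the patch does
            outs.append(scheduler.step(noise_pred[i:i + 1], timesteps[i], latents[i:i + 1], return_dict=False)[0])
        return torch.cat(outs, dim=0)

    latents = torch.randn(2, 4, 64, 64)
    noise_pred = torch.randn_like(latents)
    ts = scheduler.timesteps[torch.tensor([3, 17])]  # two different steps in one batch
    latents = step_per_sample(noise_pred, latents, ts)

Chunk-wise stepping is slower than a batched call, but it is the only safe option while the scheduler keeps per-trajectory state on the object itself.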
diff --git a/toolkit/train_tools.py b/toolkit/train_tools.py
index 267b7588..4ea23b69 100644
--- a/toolkit/train_tools.py
+++ b/toolkit/train_tools.py
@@ -584,8 +584,13 @@ def encode_prompts_xl(
 def text_encode(text_encoder: 'CLIPTextModel', tokens, truncate: bool = True, max_length=None):
     if max_length is None and not truncate:
         raise ValueError("max_length must be set if truncate is False")
-
-    tokens = tokens.to(text_encoder.device)
+    try:
+        tokens = tokens.to(text_encoder.device)
+    except Exception as e:
+        print(e)
+        print("tokens.device", tokens.device)
+        print("text_encoder.device", text_encoder.device)
+        raise e
     if truncate:
         return text_encoder(tokens)[0]
@@ -771,8 +776,8 @@ def apply_snr_weight(
 ):
     # will get it from the noise scheduler if it exists, or will calculate it if not
     all_snr = get_all_snr(noise_scheduler, loss.device)
-
-    snr = torch.stack([all_snr[t] for t in timesteps])
+    step_indices = [(noise_scheduler.timesteps == t).nonzero().item() for t in timesteps]
+    snr = torch.stack([all_snr[t] for t in step_indices])
     gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
     if fixed:
         snr_weight = gamma_over_snr.float().to(loss.device)  # directly using gamma over snr
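The `apply_snr_weight` fix mirrors the training-loop change: `timesteps` now carries scheduler timestep values, so each one must be mapped back to its position before indexing the per-step SNR table. A self-contained sketch of the lookup and a min-SNR-gamma style weight, assuming `all_snr` is indexed by position in `scheduler.timesteps`, as `get_all_snr` appears to be (the toy schedule and values below are illustrative):

    import torch

    # toy schedule: timesteps stored high-to-low, SNR table indexed by position
    timesteps_table = torch.arange(999, -1, -1)
    all_snr = torch.linspace(0.001, 20.0, steps=1000)

    batch_timesteps = torch.tensor([999, 500, 3])  # timestep *values*, not indices
    step_indices = [(timesteps_table == t).nonzero().item() for t in batch_timesteps]
    snr = torch.stack([all_snr[i] for i in step_indices])

    gamma = 5.0
    snr_weight = torch.clamp(torch.full_like(snr, gamma) / snr, max=1.0)  # min(gamma/snr, 1)
    print(step_indices, snr_weight)

Without the index lookup, a timestep value of 999 would silently read the SNR entry at position 999, which on a high-to-low schedule is the opposite end of the table.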