diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index 4f172b15..3e25e0c5 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -36,67 +36,8 @@ class SDTrainer(BaseSDTrainProcess): self.sd.text_encoder.train() def hook_train_loop(self, batch): - with torch.no_grad(): - imgs, prompts, dataset_config = batch - - # convert the 0 or 1 for is reg to a bool list - is_reg_list = dataset_config.get('is_reg', [0 for _ in range(imgs.shape[0])]) - if isinstance(is_reg_list, torch.Tensor): - is_reg_list = is_reg_list.numpy().tolist() - is_reg_list = [bool(x) for x in is_reg_list] - - conditioned_prompts = [] - - for prompt, is_reg in zip(prompts, is_reg_list): - - # make sure the embedding is in the prompts - if self.embedding is not None: - prompt = self.embedding.inject_embedding_to_prompt( - prompt, - expand_token=True, - add_if_not_present=True, - ) - - # make sure trigger is in the prompts if not a regularization run - if self.trigger_word is not None and not is_reg: - prompt = self.sd.inject_trigger_into_prompt( - prompt, - add_if_not_present=True, - ) - conditioned_prompts.append(prompt) - - batch_size = imgs.shape[0] - - dtype = get_torch_dtype(self.train_config.dtype) - imgs = imgs.to(self.device_torch, dtype=dtype) - latents = self.sd.encode_images(imgs) - - noise_scheduler = self.sd.noise_scheduler - optimizer = self.optimizer - lr_scheduler = self.lr_scheduler - - self.sd.noise_scheduler.set_timesteps( - self.train_config.max_denoising_steps, device=self.device_torch - ) - - timesteps = torch.randint(0, self.train_config.max_denoising_steps, (batch_size,), device=self.device_torch) - timesteps = timesteps.long() - - # get noise - noise = self.sd.get_latent_noise( - pixel_height=imgs.shape[2], - pixel_width=imgs.shape[3], - batch_size=batch_size, - noise_offset=self.train_config.noise_offset - ).to(self.device_torch, dtype=dtype) - - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - - # remove grads for these - noisy_latents.requires_grad = False - noise.requires_grad = False - - flush() + dtype = get_torch_dtype(self.train_config.dtype) + noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch) self.optimizer.zero_grad() @@ -135,7 +76,7 @@ class SDTrainer(BaseSDTrainProcess): if self.sd.prediction_type == 'v_prediction': # v-parameterization training - target = noise_scheduler.get_velocity(noisy_latents, noise, timesteps) + target = self.sd.noise_scheduler.get_velocity(noisy_latents, noise, timesteps) else: target = noise @@ -144,7 +85,7 @@ class SDTrainer(BaseSDTrainProcess): if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001: # add min_snr_gamma - loss = apply_snr_weight(loss, timesteps, noise_scheduler, self.train_config.min_snr_gamma) + loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma) loss = loss.mean() @@ -153,9 +94,9 @@ class SDTrainer(BaseSDTrainProcess): flush() # apply gradients - optimizer.step() - optimizer.zero_grad() - lr_scheduler.step() + self.optimizer.step() + self.optimizer.zero_grad() + self.lr_scheduler.step() if self.embedding is not None: # Let's make sure we don't update any embedding weights besides the newly added token diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 3865fdb1..dfa0a967 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -283,6 +283,66 @@ class BaseSDTrainProcess(BaseTrainProcess): else: print("load_weights not implemented for non-network models") + def process_general_training_batch(self, batch): + with torch.no_grad(): + imgs, prompts, dataset_config = batch + + # convert the 0 or 1 for is reg to a bool list + is_reg_list = dataset_config.get('is_reg', [0 for _ in range(imgs.shape[0])]) + if isinstance(is_reg_list, torch.Tensor): + is_reg_list = is_reg_list.numpy().tolist() + is_reg_list = [bool(x) for x in is_reg_list] + + conditioned_prompts = [] + + for prompt, is_reg in zip(prompts, is_reg_list): + + # make sure the embedding is in the prompts + if self.embedding is not None: + prompt = self.embedding.inject_embedding_to_prompt( + prompt, + expand_token=True, + add_if_not_present=True, + ) + + # make sure trigger is in the prompts if not a regularization run + if self.trigger_word is not None and not is_reg: + prompt = self.sd.inject_trigger_into_prompt( + prompt, + add_if_not_present=True, + ) + conditioned_prompts.append(prompt) + + batch_size = imgs.shape[0] + + dtype = get_torch_dtype(self.train_config.dtype) + imgs = imgs.to(self.device_torch, dtype=dtype) + latents = self.sd.encode_images(imgs) + + + self.sd.noise_scheduler.set_timesteps( + self.train_config.max_denoising_steps, device=self.device_torch + ) + + timesteps = torch.randint(0, self.train_config.max_denoising_steps, (batch_size,), device=self.device_torch) + timesteps = timesteps.long() + + # get noise + noise = self.sd.get_latent_noise( + pixel_height=imgs.shape[2], + pixel_width=imgs.shape[3], + batch_size=batch_size, + noise_offset=self.train_config.noise_offset + ).to(self.device_torch, dtype=dtype) + + noisy_latents = self.sd.noise_scheduler.add_noise(latents, noise, timesteps) + + # remove grads for these + noisy_latents.requires_grad = False + noise.requires_grad = False + + return noisy_latents, noise, timesteps, conditioned_prompts, imgs + def run(self): # run base process run BaseTrainProcess.run(self) diff --git a/jobs/process/TrainSliderProcess.py b/jobs/process/TrainSliderProcess.py index 73d7daf4..6e9aae48 100644 --- a/jobs/process/TrainSliderProcess.py +++ b/jobs/process/TrainSliderProcess.py @@ -169,11 +169,16 @@ class TrainSliderProcess(BaseSDTrainProcess): self.prompt_pairs = prompt_pairs # self.anchor_pairs = anchor_pairs flush() + if self.data_loader is not None: + # we will have images, prep the vae + self.sd.vae.eval() + self.sd.vae.to(self.device_torch) # end hook_before_train_loop def hook_train_loop(self, batch): dtype = get_torch_dtype(self.train_config.dtype) + # get a random pair prompt_pair: EncodedPromptPair = self.prompt_pairs[ torch.randint(0, len(self.prompt_pairs), (1,)).item() @@ -207,55 +212,67 @@ class TrainSliderProcess(BaseSDTrainProcess): ) with torch.no_grad(): - self.sd.noise_scheduler.set_timesteps( - self.train_config.max_denoising_steps, device=self.device_torch - ) - - self.optimizer.zero_grad() - - # ger a random number of steps - timesteps_to = torch.randint( - 1, self.train_config.max_denoising_steps, (1,) - ).item() - # for a complete slider, the batch size is 4 to begin with now true_batch_size = prompt_pair.target_class.text_embeds.shape[0] * self.train_config.batch_size + from_batch = False + if batch is not None: + # traing from a batch of images, not generating ourselves + from_batch = True + noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch) - # get noise - noise = self.sd.get_latent_noise( - pixel_height=height, - pixel_width=width, - batch_size=true_batch_size, - noise_offset=self.train_config.noise_offset, - ).to(self.device_torch, dtype=dtype) + denoised_latent_chunks = [noisy_latents] * self.prompt_chunk_size + denoised_latents = torch.cat(denoised_latent_chunks, dim=0) + current_timestep = timesteps + else: - # get latents - latents = noise * self.sd.noise_scheduler.init_noise_sigma - latents = latents.to(self.device_torch, dtype=dtype) - - with self.network: - assert self.network.is_active - # pass the multiplier list to the network - self.network.multiplier = prompt_pair.multiplier_list - denoised_latents = self.sd.diffuse_some_steps( - latents, # pass simple noise latents - train_tools.concat_prompt_embeddings( - prompt_pair.positive_target, # unconditional - prompt_pair.target_class, # target - self.train_config.batch_size, - ), - start_timesteps=0, - total_timesteps=timesteps_to, - guidance_scale=3, + self.sd.noise_scheduler.set_timesteps( + self.train_config.max_denoising_steps, device=self.device_torch ) - # split the latents into out prompt pair chunks - denoised_latent_chunks = torch.chunk(denoised_latents, self.prompt_chunk_size, dim=0) + self.optimizer.zero_grad() - noise_scheduler.set_timesteps(1000) + # ger a random number of steps + timesteps_to = torch.randint( + 1, self.train_config.max_denoising_steps, (1,) + ).item() - current_timestep_index = int(timesteps_to * 1000 / self.train_config.max_denoising_steps) - current_timestep = noise_scheduler.timesteps[current_timestep_index] + + # get noise + noise = self.sd.get_latent_noise( + pixel_height=height, + pixel_width=width, + batch_size=true_batch_size, + noise_offset=self.train_config.noise_offset, + ).to(self.device_torch, dtype=dtype) + + # get latents + latents = noise * self.sd.noise_scheduler.init_noise_sigma + latents = latents.to(self.device_torch, dtype=dtype) + + with self.network: + assert self.network.is_active + # pass the multiplier list to the network + self.network.multiplier = prompt_pair.multiplier_list + denoised_latents = self.sd.diffuse_some_steps( + latents, # pass simple noise latents + train_tools.concat_prompt_embeddings( + prompt_pair.positive_target, # unconditional + prompt_pair.target_class, # target + self.train_config.batch_size, + ), + start_timesteps=0, + total_timesteps=timesteps_to, + guidance_scale=3, + ) + + noise_scheduler.set_timesteps(1000) + + # split the latents into out prompt pair chunks + denoised_latent_chunks = torch.chunk(denoised_latents, self.prompt_chunk_size, dim=0) + denoised_latent_chunks = [x.detach() for x in denoised_latent_chunks] + + current_timestep_index = int(timesteps_to * 1000 / self.train_config.max_denoising_steps) + current_timestep = noise_scheduler.timesteps[current_timestep_index] # flush() # 4.2GB to 3GB on 512x512 @@ -267,6 +284,7 @@ class TrainSliderProcess(BaseSDTrainProcess): current_timestep, denoised_latents ) + positive_latents = positive_latents.detach() positive_latents.requires_grad = False positive_latents_chunks = torch.chunk(positive_latents, self.prompt_chunk_size, dim=0) @@ -277,6 +295,7 @@ class TrainSliderProcess(BaseSDTrainProcess): current_timestep, denoised_latents ) + neutral_latents = neutral_latents.detach() neutral_latents.requires_grad = False neutral_latents_chunks = torch.chunk(neutral_latents, self.prompt_chunk_size, dim=0) @@ -287,9 +306,12 @@ class TrainSliderProcess(BaseSDTrainProcess): current_timestep, denoised_latents ) + unconditional_latents = unconditional_latents.detach() unconditional_latents.requires_grad = False unconditional_latents_chunks = torch.chunk(unconditional_latents, self.prompt_chunk_size, dim=0) + denoised_latents = denoised_latents.detach() + flush() # 4.2GB to 3GB on 512x512 # 4.20 GB RAM for 512x512 @@ -402,10 +424,14 @@ class TrainSliderProcess(BaseSDTrainProcess): loss = loss.mean([1, 2, 3]) if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001: - # match batch size - timesteps_index_list = [current_timestep_index for _ in range(target_latents.shape[0])] - # add min_snr_gamma - loss = apply_snr_weight(loss, timesteps_index_list, noise_scheduler, self.train_config.min_snr_gamma) + if from_batch: + # match batch size + loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma) + else: + # match batch size + timesteps_index_list = [current_timestep_index for _ in range(target_latents.shape[0])] + # add min_snr_gamma + loss = apply_snr_weight(loss, timesteps_index_list, noise_scheduler, self.train_config.min_snr_gamma) loss = loss.mean() * prompt_pair_chunk.weight @@ -427,7 +453,7 @@ class TrainSliderProcess(BaseSDTrainProcess): positive_latents, neutral_latents, unconditional_latents, - latents + # latents ) # move back to cpu prompt_pair.to("cpu") diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index 85fa2823..1bff598d 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -77,6 +77,7 @@ class TrainConfig: self.skip_first_sample = kwargs.get('skip_first_sample', False) self.gradient_checkpointing = kwargs.get('gradient_checkpointing', True) self.weight_jitter = kwargs.get('weight_jitter', 0.0) + self.merge_network_on_save = kwargs.get('merge_network_on_save', False) class ModelConfig: diff --git a/toolkit/llvae.py b/toolkit/llvae.py index cfb8ca27..aedbcb61 100644 --- a/toolkit/llvae.py +++ b/toolkit/llvae.py @@ -101,14 +101,16 @@ if __name__ == '__main__': from PIL import Image import torchvision.transforms as transforms user_path = os.path.expanduser('~') + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + dtype = torch.float32 input_path = os.path.join(user_path, "Pictures/sample_2_512.png") output_path = os.path.join(user_path, "Pictures/sample_2_512_llvae.png") img = Image.open(input_path) img_tensor = transforms.ToTensor()(img) - img_tensor = img_tensor.unsqueeze(0) + img_tensor = img_tensor.unsqueeze(0).to(device=device, dtype=dtype) print("input_shape: ", list(img_tensor.shape)) - vae = LosslessLatentVAE(in_channels=3, latent_depth=8) + vae = LosslessLatentVAE(in_channels=3, latent_depth=8, dtype=dtype).to(device=device, dtype=dtype) latent = vae.encode(img_tensor) print("latent_shape: ", list(latent.shape)) out_tensor = vae.decode(latent) diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index cbd0d4fb..f4b7bb32 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -152,6 +152,11 @@ class StableDiffusion: steps_offset=1 ) + # move the betas alphas and alphas_cumprod to device. Sometimed they get stuck on cpu, not sure why + scheduler.betas = scheduler.betas.to(self.device_torch) + scheduler.alphas = scheduler.alphas.to(self.device_torch) + scheduler.alphas_cumprod = scheduler.alphas_cumprod.to(self.device_torch) + model_path = self.model_config.name_or_path if 'civitai.com' in self.model_config.name_or_path: # load is a civit ai model, use the loader. @@ -323,6 +328,13 @@ class StableDiffusion: # todo do we disable text encoder here as well if disabled for model, or only do that for training? if self.is_xl: + # fix guidance rescale for sdxl + # was trained on 0.7 (I believe) + + grs = gen_config.guidance_rescale + if grs is None or grs < 0.00001: + grs = 0.7 + img = pipeline( prompt=gen_config.prompt, prompt_2=gen_config.prompt_2, @@ -332,7 +344,7 @@ class StableDiffusion: width=gen_config.width, num_inference_steps=gen_config.num_inference_steps, guidance_scale=gen_config.guidance_scale, - guidance_rescale=gen_config.guidance_rescale, + guidance_rescale=grs, ).images[0] else: img = pipeline( @@ -619,6 +631,27 @@ class StableDiffusion: return latents + def decode_latents( + self, + latents: torch.Tensor, + device=None, + dtype=None + ): + if device is None: + device = self.device + if dtype is None: + dtype = self.torch_dtype + + # Move to vae to device if on cpu + if self.vae.device == 'cpu': + self.vae.to(self.device) + latents = latents.to(device, dtype=dtype) + latents = latents / 0.18215 + images = self.vae.decode(latents).sample + images = images.to(device, dtype=dtype) + + return images + def encode_image_prompt_pairs( self, prompt_list: List[str],