From be804c9cf5bb760910b52da9441e5e5b8c33c300 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Sat, 9 Sep 2023 12:02:07 -0600 Subject: [PATCH] Save embeddings as their trigger to match auto and comfy style loading. Also, FINALLY found why gradients were wonkey and fixed it. The root problem is dropping out of network state before backward pass. --- extensions_built_in/sd_trainer/SDTrainer.py | 38 ++++++----- jobs/process/BaseSDTrainProcess.py | 72 ++++++++++++--------- toolkit/lora_special.py | 5 +- toolkit/lycoris_special.py | 7 +- toolkit/network_mixins.py | 13 ---- 5 files changed, 64 insertions(+), 71 deletions(-) diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index a4b3cf28..704bdcd8 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -69,29 +69,33 @@ class SDTrainer(BaseSDTrainProcess): guidance_scale=1.0, ) flush() - # 9.18 gb - noise = noise.to(self.device_torch, dtype=dtype).detach() + # 9.18 gb + noise = noise.to(self.device_torch, dtype=dtype).detach() - if self.sd.prediction_type == 'v_prediction': - # v-parameterization training - target = self.sd.noise_scheduler.get_velocity(noisy_latents, noise, timesteps) - else: - target = noise + if self.sd.prediction_type == 'v_prediction': + # v-parameterization training + target = self.sd.noise_scheduler.get_velocity(noisy_latents, noise, timesteps) + else: + target = noise - loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") - loss = loss.mean([1, 2, 3]) + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + loss = loss.mean([1, 2, 3]) - if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001: - # add min_snr_gamma - loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma) + if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001: + # add min_snr_gamma + loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma) - loss = loss.mean() + loss = loss.mean() - # back propagate loss to free ram - loss.backward() - torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm) - flush() + # IMPORTANT if gradient checkpointing do not leave with network when doing backward + # it will destroy the gradients. This is because the network is a context manager + # and will change the multipliers back to 0.0 when exiting. They will be + # 0.0 for the backward pass and the gradients will be 0.0 + # I spent weeks on fighting this. DON'T DO IT + loss.backward() + torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm) + flush() # apply gradients self.optimizer.step() diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 5cfa75b6..27acd5d3 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -255,32 +255,37 @@ class BaseSDTrainProcess(BaseTrainProcess): file_path = os.path.join(self.save_root, filename) # prepare meta save_meta = get_meta_for_safetensors(self.meta, self.job.name) - if self.network is not None: - prev_multiplier = self.network.multiplier - self.network.multiplier = 1.0 - if self.network_config.normalize: - # apply the normalization - self.network.apply_stored_normalizer() + if self.network is not None or self.embedding is not None: + if self.network is not None: + prev_multiplier = self.network.multiplier + self.network.multiplier = 1.0 + if self.network_config.normalize: + # apply the normalization + self.network.apply_stored_normalizer() - # if we are doing embedding training as well, add that - embedding_dict = self.embedding.state_dict() if self.embedding else None - self.network.save_weights( - file_path, - dtype=get_torch_dtype(self.save_config.dtype), - metadata=save_meta, - extra_state_dict=embedding_dict - ) - self.network.multiplier = prev_multiplier - # if we have an embedding as well, pair it with the network - elif self.embedding is not None: - # for combo, above will get it - # set current step - self.embedding.step = self.step_num - # change filename to pt if that is set - if self.embed_config.save_format == "pt": - # replace extension - file_path = os.path.splitext(file_path)[0] + ".pt" - self.embedding.save(file_path) + # if we are doing embedding training as well, add that + embedding_dict = self.embedding.state_dict() if self.embedding else None + self.network.save_weights( + file_path, + dtype=get_torch_dtype(self.save_config.dtype), + metadata=save_meta, + extra_state_dict=embedding_dict + ) + self.network.multiplier = prev_multiplier + # if we have an embedding as well, pair it with the network + + # even if added to lora, still save the trigger version + if self.embedding is not None: + emb_filename = f'{self.embed_config.trigger}{step_num}.safetensors' + emb_file_path = os.path.join(self.save_root, emb_filename) + # for combo, above will get it + # set current step + self.embedding.step = self.step_num + # change filename to pt if that is set + if self.embed_config.save_format == "pt": + # replace extension + file_path = os.path.splitext(emb_file_path)[0] + ".pt" + self.embedding.save(file_path) else: self.sd.save( file_path, @@ -316,17 +321,19 @@ class BaseSDTrainProcess(BaseTrainProcess): # return loss return 0.0 - def get_latest_save_path(self): + def get_latest_save_path(self, name=None): + if name == None: + name = self.job.name # get latest saved step if os.path.exists(self.save_root): latest_file = None # pattern is {job_name}_{zero_filles_step}.safetensors or {job_name}.safetensors - pattern = f"{self.job.name}*.safetensors" + pattern = f"{name}*.safetensors" files = glob.glob(os.path.join(self.save_root, pattern)) if len(files) > 0: latest_file = max(files, key=os.path.getctime) # try pt - pattern = f"{self.job.name}*.pt" + pattern = f"{name}*.pt" files = glob.glob(os.path.join(self.save_root, pattern)) if len(files) > 0: latest_file = max(files, key=os.path.getctime) @@ -548,9 +555,12 @@ class BaseSDTrainProcess(BaseTrainProcess): # we are doing embedding training as well self.embedding = Embedding( sd=self.sd, - embed_config=self.embed_config, - state_dict=extra_weights + embed_config=self.embed_config ) + latest_save_path = self.get_latest_save_path(self.embed_config.trigger) + # load last saved weights + if latest_save_path is not None: + self.embedding.load_embedding_from_file(latest_save_path, self.device_torch) params.append({ 'params': self.embedding.get_trainable_params(), 'lr': self.train_config.embedding_lr @@ -562,7 +572,7 @@ class BaseSDTrainProcess(BaseTrainProcess): sd=self.sd, embed_config=self.embed_config ) - latest_save_path = self.get_latest_save_path() + latest_save_path = self.get_latest_save_path(self.embed_config.trigger) # load last saved weights if latest_save_path is not None: self.embedding.load_embedding_from_file(latest_save_path, self.device_torch) diff --git a/toolkit/lora_special.py b/toolkit/lora_special.py index 2ef1d3aa..c0c27141 100644 --- a/toolkit/lora_special.py +++ b/toolkit/lora_special.py @@ -50,10 +50,7 @@ class LoRAModule(ToolkitModuleMixin, torch.nn.Module): **kwargs ): """if alpha == 0 or None, alpha is rank (no scaling).""" - super().__init__( - org_module=org_module, - parent=parent - ) + super().__init__() self.lora_name = lora_name self.scalar = torch.tensor(1.0) diff --git a/toolkit/lycoris_special.py b/toolkit/lycoris_special.py index 88e98ab3..b881f9f2 100644 --- a/toolkit/lycoris_special.py +++ b/toolkit/lycoris_special.py @@ -36,12 +36,7 @@ class LoConSpecialModule(ToolkitModuleMixin, LoConModule): # call super of super torch.nn.Module.__init__(self) # call super of - super().__init__( - org_module=org_module, - call_super_init=False, - parent=parent, - **kwargs - ) + super().__init__(call_super_init=False) self.lora_name = lora_name self.lora_dim = lora_dim self.cp = False diff --git a/toolkit/network_mixins.py b/toolkit/network_mixins.py index ea350542..4934ba93 100644 --- a/toolkit/network_mixins.py +++ b/toolkit/network_mixins.py @@ -52,18 +52,9 @@ class ToolkitModuleMixin: ): if call_super_init: super().__init__(*args, **kwargs) - self.tk_orig_module: torch.nn.Module = kwargs.get('org_module', None) - self.tk_orig_parent = kwargs.get('parent', None) self.is_checkpointing = False self.is_normalizing = False self.normalize_scaler = 1.0 - # see if is conv or linear - self.is_conv = False - self.is_linear = False - if self.tk_orig_module.__class__.__name__ in LINEAR_MODULES: - self.is_linear = True - elif self.tk_orig_module.__class__.__name__ in CONV_MODULES: - self.is_conv = True self._multiplier: Union[float, list, torch.Tensor] = 1.0 # this allows us to set different multipliers on a per item in a batch basis @@ -140,10 +131,6 @@ class ToolkitModuleMixin: lora_output_batch_size = lora_output.size(0) multiplier_batch_size = multiplier.size(0) if lora_output_batch_size != multiplier_batch_size: - print( - f"Warning: lora_output_batch_size {lora_output_batch_size} != multiplier_batch_size {multiplier_batch_size}") - # doing cfg - # should be 1 for if total batch size was 1 num_interleaves = (lora_output_batch_size // 2) // multiplier_batch_size multiplier = multiplier.repeat_interleave(num_interleaves) # multiplier = 1.0