diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index a4b3cf28..704bdcd8 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -69,29 +69,33 @@ class SDTrainer(BaseSDTrainProcess): guidance_scale=1.0, ) flush() - # 9.18 gb - noise = noise.to(self.device_torch, dtype=dtype).detach() + # 9.18 gb + noise = noise.to(self.device_torch, dtype=dtype).detach() - if self.sd.prediction_type == 'v_prediction': - # v-parameterization training - target = self.sd.noise_scheduler.get_velocity(noisy_latents, noise, timesteps) - else: - target = noise + if self.sd.prediction_type == 'v_prediction': + # v-parameterization training + target = self.sd.noise_scheduler.get_velocity(noisy_latents, noise, timesteps) + else: + target = noise - loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") - loss = loss.mean([1, 2, 3]) + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + loss = loss.mean([1, 2, 3]) - if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001: - # add min_snr_gamma - loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma) + if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001: + # add min_snr_gamma + loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma) - loss = loss.mean() + loss = loss.mean() - # back propagate loss to free ram - loss.backward() - torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm) - flush() + # IMPORTANT if gradient checkpointing do not leave with network when doing backward + # it will destroy the gradients. This is because the network is a context manager + # and will change the multipliers back to 0.0 when exiting. They will be + # 0.0 for the backward pass and the gradients will be 0.0 + # I spent weeks on fighting this. DON'T DO IT + loss.backward() + torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm) + flush() # apply gradients self.optimizer.step() diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 5cfa75b6..27acd5d3 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -255,32 +255,37 @@ class BaseSDTrainProcess(BaseTrainProcess): file_path = os.path.join(self.save_root, filename) # prepare meta save_meta = get_meta_for_safetensors(self.meta, self.job.name) - if self.network is not None: - prev_multiplier = self.network.multiplier - self.network.multiplier = 1.0 - if self.network_config.normalize: - # apply the normalization - self.network.apply_stored_normalizer() + if self.network is not None or self.embedding is not None: + if self.network is not None: + prev_multiplier = self.network.multiplier + self.network.multiplier = 1.0 + if self.network_config.normalize: + # apply the normalization + self.network.apply_stored_normalizer() - # if we are doing embedding training as well, add that - embedding_dict = self.embedding.state_dict() if self.embedding else None - self.network.save_weights( - file_path, - dtype=get_torch_dtype(self.save_config.dtype), - metadata=save_meta, - extra_state_dict=embedding_dict - ) - self.network.multiplier = prev_multiplier - # if we have an embedding as well, pair it with the network - elif self.embedding is not None: - # for combo, above will get it - # set current step - self.embedding.step = self.step_num - # change filename to pt if that is set - if self.embed_config.save_format == "pt": - # replace extension - file_path = os.path.splitext(file_path)[0] + ".pt" - self.embedding.save(file_path) + # if we are doing embedding training as well, add that + embedding_dict = self.embedding.state_dict() if self.embedding else None + self.network.save_weights( + file_path, + dtype=get_torch_dtype(self.save_config.dtype), + metadata=save_meta, + extra_state_dict=embedding_dict + ) + self.network.multiplier = prev_multiplier + # if we have an embedding as well, pair it with the network + + # even if added to lora, still save the trigger version + if self.embedding is not None: + emb_filename = f'{self.embed_config.trigger}{step_num}.safetensors' + emb_file_path = os.path.join(self.save_root, emb_filename) + # for combo, above will get it + # set current step + self.embedding.step = self.step_num + # change filename to pt if that is set + if self.embed_config.save_format == "pt": + # replace extension + file_path = os.path.splitext(emb_file_path)[0] + ".pt" + self.embedding.save(file_path) else: self.sd.save( file_path, @@ -316,17 +321,19 @@ class BaseSDTrainProcess(BaseTrainProcess): # return loss return 0.0 - def get_latest_save_path(self): + def get_latest_save_path(self, name=None): + if name == None: + name = self.job.name # get latest saved step if os.path.exists(self.save_root): latest_file = None # pattern is {job_name}_{zero_filles_step}.safetensors or {job_name}.safetensors - pattern = f"{self.job.name}*.safetensors" + pattern = f"{name}*.safetensors" files = glob.glob(os.path.join(self.save_root, pattern)) if len(files) > 0: latest_file = max(files, key=os.path.getctime) # try pt - pattern = f"{self.job.name}*.pt" + pattern = f"{name}*.pt" files = glob.glob(os.path.join(self.save_root, pattern)) if len(files) > 0: latest_file = max(files, key=os.path.getctime) @@ -548,9 +555,12 @@ class BaseSDTrainProcess(BaseTrainProcess): # we are doing embedding training as well self.embedding = Embedding( sd=self.sd, - embed_config=self.embed_config, - state_dict=extra_weights + embed_config=self.embed_config ) + latest_save_path = self.get_latest_save_path(self.embed_config.trigger) + # load last saved weights + if latest_save_path is not None: + self.embedding.load_embedding_from_file(latest_save_path, self.device_torch) params.append({ 'params': self.embedding.get_trainable_params(), 'lr': self.train_config.embedding_lr @@ -562,7 +572,7 @@ class BaseSDTrainProcess(BaseTrainProcess): sd=self.sd, embed_config=self.embed_config ) - latest_save_path = self.get_latest_save_path() + latest_save_path = self.get_latest_save_path(self.embed_config.trigger) # load last saved weights if latest_save_path is not None: self.embedding.load_embedding_from_file(latest_save_path, self.device_torch) diff --git a/toolkit/lora_special.py b/toolkit/lora_special.py index 2ef1d3aa..c0c27141 100644 --- a/toolkit/lora_special.py +++ b/toolkit/lora_special.py @@ -50,10 +50,7 @@ class LoRAModule(ToolkitModuleMixin, torch.nn.Module): **kwargs ): """if alpha == 0 or None, alpha is rank (no scaling).""" - super().__init__( - org_module=org_module, - parent=parent - ) + super().__init__() self.lora_name = lora_name self.scalar = torch.tensor(1.0) diff --git a/toolkit/lycoris_special.py b/toolkit/lycoris_special.py index 88e98ab3..b881f9f2 100644 --- a/toolkit/lycoris_special.py +++ b/toolkit/lycoris_special.py @@ -36,12 +36,7 @@ class LoConSpecialModule(ToolkitModuleMixin, LoConModule): # call super of super torch.nn.Module.__init__(self) # call super of - super().__init__( - org_module=org_module, - call_super_init=False, - parent=parent, - **kwargs - ) + super().__init__(call_super_init=False) self.lora_name = lora_name self.lora_dim = lora_dim self.cp = False diff --git a/toolkit/network_mixins.py b/toolkit/network_mixins.py index ea350542..4934ba93 100644 --- a/toolkit/network_mixins.py +++ b/toolkit/network_mixins.py @@ -52,18 +52,9 @@ class ToolkitModuleMixin: ): if call_super_init: super().__init__(*args, **kwargs) - self.tk_orig_module: torch.nn.Module = kwargs.get('org_module', None) - self.tk_orig_parent = kwargs.get('parent', None) self.is_checkpointing = False self.is_normalizing = False self.normalize_scaler = 1.0 - # see if is conv or linear - self.is_conv = False - self.is_linear = False - if self.tk_orig_module.__class__.__name__ in LINEAR_MODULES: - self.is_linear = True - elif self.tk_orig_module.__class__.__name__ in CONV_MODULES: - self.is_conv = True self._multiplier: Union[float, list, torch.Tensor] = 1.0 # this allows us to set different multipliers on a per item in a batch basis @@ -140,10 +131,6 @@ class ToolkitModuleMixin: lora_output_batch_size = lora_output.size(0) multiplier_batch_size = multiplier.size(0) if lora_output_batch_size != multiplier_batch_size: - print( - f"Warning: lora_output_batch_size {lora_output_batch_size} != multiplier_batch_size {multiplier_batch_size}") - # doing cfg - # should be 1 for if total batch size was 1 num_interleaves = (lora_output_batch_size // 2) // multiplier_batch_size multiplier = multiplier.repeat_interleave(num_interleaves) # multiplier = 1.0