From 61badf85a765d33250aabd5ea9aeb62434153e99 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Sun, 17 Sep 2023 15:56:43 -0600 Subject: [PATCH] t2i training working from what I can tell at least --- extensions_built_in/sd_trainer/SDTrainer.py | 75 ++--- jobs/process/BaseSDTrainProcess.py | 303 +++++++++++--------- toolkit/config_modules.py | 1 + toolkit/saving.py | 2 +- toolkit/stable_diffusion_model.py | 7 +- 5 files changed, 214 insertions(+), 174 deletions(-) diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index 4be9bd5c..ed4625f8 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -54,16 +54,20 @@ class SDTrainer(BaseSDTrainProcess): if os.path.exists(os.path.join(adapter_folder_path, file_name_no_ext + ext)): adapter_images.append(os.path.join(adapter_folder_path, file_name_no_ext + ext)) break - + width, height = batch.file_items[0].crop_width, batch.file_items[0].crop_height adapter_tensors = [] # load images with torch transforms - for adapter_image in adapter_images: + for idx, adapter_image in enumerate(adapter_images): img = Image.open(adapter_image) + # resize to match batch shape + img = img.resize((width, height)) img = adapter_transforms(img) adapter_tensors.append(img) # stack them - adapter_tensors = torch.stack(adapter_tensors) + adapter_tensors = torch.stack(adapter_tensors).to( + self.device_torch, dtype=get_torch_dtype(self.train_config.dtype) + ) return adapter_tensors def hook_train_loop(self, batch): @@ -79,8 +83,8 @@ class SDTrainer(BaseSDTrainProcess): adapter_images = self.get_adapter_images(batch) # not 100% sure what this does. But they do it here # https://github.com/huggingface/diffusers/blob/38a664a3d61e27ab18cd698231422b3c38d6eebf/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L1170 - sigmas = self.get_sigmas(timesteps, len(noisy_latents.shape), noisy_latents.dtype) - noisy_latents = noisy_latents / ((sigmas ** 2 + 1) ** 0.5) + # sigmas = self.get_sigmas(timesteps, len(noisy_latents.shape), noisy_latents.dtype) + # noisy_latents = noisy_latents / ((sigmas ** 2 + 1) ** 0.5) # flush() self.optimizer.zero_grad() @@ -126,39 +130,38 @@ class SDTrainer(BaseSDTrainProcess): **pred_kwargs ) - if self.adapter: - # todo, diffusers does this on t2i training, is it better approach? - # Denoise the latents - denoised_latents = noise_pred * (-sigmas) + noisy_latents - weighing = sigmas ** -2.0 - - # Get the target for loss depending on the prediction type - if self.sd.noise_scheduler.config.prediction_type == "epsilon": - target = batch.latents # we are computing loss against denoise latents - elif self.sd.noise_scheduler.config.prediction_type == "v_prediction": - target = self.sd.noise_scheduler.get_velocity(batch.latents, noise, timesteps) - else: - raise ValueError(f"Unknown prediction type {self.sd.noise_scheduler.config.prediction_type}") - - # MSE loss - loss = torch.mean( - (weighing.float() * (denoised_latents.float() - target.float()) ** 2).reshape(target.shape[0], -1), - dim=1, - ) + # if self.adapter: + # # todo, diffusers does this on t2i training, is it better approach? 
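+        # # (For reference, that example scales the UNet input by 1 / sqrt(sigma^2 + 1), recovers the
+        # #  denoised latents as noise_pred * (-sigma) + noisy_latents, and weights the MSE against the
+        # #  clean latents by sigma^-2 -- which is what the commented-out block below does.)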
+ # # Denoise the latents + # denoised_latents = noise_pred * (-sigmas) + noisy_latents + # weighing = sigmas ** -2.0 + # + # # Get the target for loss depending on the prediction type + # if self.sd.noise_scheduler.config.prediction_type == "epsilon": + # target = batch.latents # we are computing loss against denoise latents + # elif self.sd.noise_scheduler.config.prediction_type == "v_prediction": + # target = self.sd.noise_scheduler.get_velocity(batch.latents, noise, timesteps) + # else: + # raise ValueError(f"Unknown prediction type {self.sd.noise_scheduler.config.prediction_type}") + # + # # MSE loss + # loss = torch.mean( + # (weighing.float() * (denoised_latents.float() - target.float()) ** 2).reshape(target.shape[0], -1), + # dim=1, + # ) + # else: + noise = noise.to(self.device_torch, dtype=dtype).detach() + if self.sd.prediction_type == 'v_prediction': + # v-parameterization training + target = self.sd.noise_scheduler.get_velocity(noisy_latents, noise, timesteps) else: - noise = noise.to(self.device_torch, dtype=dtype).detach() - if self.sd.prediction_type == 'v_prediction': - # v-parameterization training - target = self.sd.noise_scheduler.get_velocity(noisy_latents, noise, timesteps) - else: - target = noise - loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") - loss = loss.mean([1, 2, 3]) + target = noise + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + loss = loss.mean([1, 2, 3]) - # TODO: I think the sigma method does not need this. Check - if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001: - # add min_snr_gamma - loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma) + if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001: + # add min_snr_gamma + loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma) loss = loss.mean() diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 5f0049af..90f11354 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -160,6 +160,11 @@ class BaseSDTrainProcess(BaseTrainProcess): train_embedding=self.embed_config is not None, ) + # fine_tuning here is for training actual SD network, not LoRA, embeddings, etc. 
+        # it is (Dreambooth, etc)
+        self.is_fine_tuning = True
+        if self.network_config is not None or self.adapter_config is not None or self.embed_config is not None:
+            self.is_fine_tuning = False
+
     def sample(self, step=None, is_first=False):
         sample_folder = os.path.join(self.save_root, 'samples')
         gen_img_config_list = []
@@ -194,6 +199,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
                     prompt, self.trigger_word, add_if_not_present=False
                 )
 
+                extra_args = {}
+                if self.adapter_config is not None:
+                    extra_args['adapter_image_path'] = self.adapter_config.test_img_path
+
                 gen_img_config_list.append(GenerateImageConfig(
                     prompt=prompt,  # it will autoparse the prompt
                     width=sample_config.width,
@@ -206,6 +215,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
                     network_multiplier=sample_config.network_multiplier,
                     output_path=output_path,
                     output_ext=sample_config.ext,
+                    **extra_args
                 ))
 
         # send to be generated
@@ -287,8 +297,15 @@ class BaseSDTrainProcess(BaseTrainProcess):
         file_path = os.path.join(self.save_root, filename)
         # prepare meta
         save_meta = get_meta_for_safetensors(self.meta, self.job.name)
-        if self.network is not None or self.embedding is not None:
+        if self.network is not None or self.embedding is not None or self.adapter is not None:
             if self.network is not None:
+                lora_name = self.job.name
+                if self.adapter_config is not None or self.embedding is not None:
+                    # add _lora to name
+                    lora_name += '_LoRA'
+
+                filename = f'{lora_name}{step_num}.safetensors'
+                file_path = os.path.join(self.save_root, filename)
                 prev_multiplier = self.network.multiplier
                 self.network.multiplier = 1.0
                 if self.network_config.normalize:
@@ -318,15 +335,23 @@ class BaseSDTrainProcess(BaseTrainProcess):
                     # replace extension
                     emb_file_path = os.path.splitext(emb_file_path)[0] + ".pt"
                     self.embedding.save(emb_file_path)
-            elif self.adapter is not None:
-                # save adapter
-                state_dict = self.adapter.state_dict()
-                save_t2i_from_diffusers(
-                    state_dict,
-                    output_file=file_path,
-                    meta=save_meta,
-                    dtype=get_torch_dtype(self.save_config.dtype)
-                )
+
+            if self.adapter is not None:
+                adapter_name = self.job.name
+                if self.network_config is not None or self.embedding is not None:
+                    # add _t2i to name
+                    adapter_name += '_t2i'
+
+                filename = f'{adapter_name}{step_num}.safetensors'
+                file_path = os.path.join(self.save_root, filename)
+                # save adapter
+                state_dict = self.adapter.state_dict()
+                save_t2i_from_diffusers(
+                    state_dict,
+                    output_file=file_path,
+                    meta=save_meta,
+                    dtype=get_torch_dtype(self.save_config.dtype)
+                )
         else:
             self.sd.save(
                 file_path,
@@ -362,14 +387,14 @@ class BaseSDTrainProcess(BaseTrainProcess):
             # return loss
             return 0.0
 
-    def get_latest_save_path(self, name=None):
+    def get_latest_save_path(self, name=None, post=''):
         if name == None:
             name = self.job.name
         # get latest saved step
         if os.path.exists(self.save_root):
             latest_file = None
             # pattern is {job_name}_{zero_filles_step}.safetensors or {job_name}.safetensors
-            pattern = f"{name}*.safetensors"
+            pattern = f"{name}*{post}.safetensors"
             files = glob.glob(os.path.join(self.save_root, pattern))
             if len(files) > 0:
                 latest_file = max(files, key=os.path.getctime)
@@ -399,17 +424,19 @@ class BaseSDTrainProcess(BaseTrainProcess):
             print("load_weights not implemented for non-network models")
             return None
 
-    def get_sigmas(self, timesteps, n_dim=4, dtype=torch.float32):
-        sigmas = self.sd.noise_scheduler.sigmas.to(device=self.device_torch, dtype=dtype)
-        schedule_timesteps = self.sd.noise_scheduler.timesteps.to(self.device_torch, )
-        timesteps = timesteps.to(self.device_torch, )
-
-
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] - - sigma = sigmas[step_indices].flatten() - while len(sigma.shape) < n_dim: - sigma = sigma.unsqueeze(-1) - return sigma + # def get_sigmas(self, timesteps, n_dim=4, dtype=torch.float32): + # self.sd.noise_scheduler.set_timesteps(1000, device=self.device_torch) + # sigmas = self.sd.noise_scheduler.sigmas.to(device=self.device_torch, dtype=dtype) + # schedule_timesteps = self.sd.noise_scheduler.timesteps.to(self.device_torch, ) + # timesteps = timesteps.to(self.device_torch, ) + # + # # step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + # step_indices = [t for t in timesteps] + # + # sigma = sigmas[step_indices].flatten() + # while len(sigma.shape) < n_dim: + # sigma = sigma.unsqueeze(-1) + # return sigma def process_general_training_batch(self, batch: 'DataLoaderBatchDTO'): with torch.no_grad(): @@ -583,54 +610,52 @@ class BaseSDTrainProcess(BaseTrainProcess): if self.datasets_reg is not None: self.data_loader_reg = get_dataloader_from_datasets(self.datasets_reg, self.train_config.batch_size, self.sd) + params = [] + if self.network_config is not None or self.adapter_config is not None or self.embed_config is not None: + if self.network_config is not None: + # TODO should we completely switch to LycorisSpecialNetwork? - if self.network_config is not None: - # TODO should we completely switch to LycorisSpecialNetwork? + is_lycoris = False + # default to LoCON if there are any conv layers or if it is named + NetworkClass = LoRASpecialNetwork + if self.network_config.type.lower() == 'locon' or self.network_config.type.lower() == 'lycoris': + NetworkClass = LycorisSpecialNetwork + is_lycoris = True - is_lycoris = False - # default to LoCON if there are any conv layers or if it is named - NetworkClass = LoRASpecialNetwork - if self.network_config.type.lower() == 'locon' or self.network_config.type.lower() == 'lycoris': - NetworkClass = LycorisSpecialNetwork - is_lycoris = True + # if is_lycoris: + # preset = PRESET['full'] + # NetworkClass.apply_preset(preset) - # if is_lycoris: - # preset = PRESET['full'] - # NetworkClass.apply_preset(preset) + self.network = NetworkClass( + text_encoder=text_encoder, + unet=unet, + lora_dim=self.network_config.linear, + multiplier=1.0, + alpha=self.network_config.linear_alpha, + train_unet=self.train_config.train_unet, + train_text_encoder=self.train_config.train_text_encoder, + conv_lora_dim=self.network_config.conv, + conv_alpha=self.network_config.conv_alpha, + is_sdxl=self.model_config.is_xl, + is_v2=self.model_config.is_v2, + dropout=self.network_config.dropout + ) - self.network = NetworkClass( - text_encoder=text_encoder, - unet=unet, - lora_dim=self.network_config.linear, - multiplier=1.0, - alpha=self.network_config.linear_alpha, - train_unet=self.train_config.train_unet, - train_text_encoder=self.train_config.train_text_encoder, - conv_lora_dim=self.network_config.conv, - conv_alpha=self.network_config.conv_alpha, - is_sdxl=self.model_config.is_xl, - is_v2=self.model_config.is_v2, - dropout=self.network_config.dropout - ) + self.network.force_to(self.device_torch, dtype=dtype) + # give network to sd so it can use it + self.sd.network = self.network + self.network._update_torch_multiplier() - self.network.force_to(self.device_torch, dtype=dtype) - # give network to sd so it can use it - self.sd.network = self.network - self.network._update_torch_multiplier() + self.network.apply_to( + text_encoder, + unet, + 
self.train_config.train_text_encoder, + self.train_config.train_unet + ) - self.network.apply_to( - text_encoder, - unet, - self.train_config.train_text_encoder, - self.train_config.train_unet - ) + self.network.prepare_grad_etc(text_encoder, unet) + flush() - self.network.prepare_grad_etc(text_encoder, unet) - flush() - - params = self.get_params() - - if not params: # LyCORIS doesnt have default_lr config = { 'text_encoder_lr': self.train_config.lr, @@ -639,23 +664,30 @@ class BaseSDTrainProcess(BaseTrainProcess): sig = inspect.signature(self.network.prepare_optimizer_params) if 'default_lr' in sig.parameters: config['default_lr'] = self.train_config.lr - params = self.network.prepare_optimizer_params( + params_net = self.network.prepare_optimizer_params( **config ) - if self.train_config.gradient_checkpointing: - self.network.enable_gradient_checkpointing() + params += params_net - # set the network to normalize if we are - self.network.is_normalizing = self.network_config.normalize + if self.train_config.gradient_checkpointing: + self.network.enable_gradient_checkpointing() - latest_save_path = self.get_latest_save_path() - extra_weights = None - if latest_save_path is not None: - self.print(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####") - self.print(f"Loading from {latest_save_path}") - extra_weights = self.load_weights(latest_save_path) - self.network.multiplier = 1.0 + # set the network to normalize if we are + self.network.is_normalizing = self.network_config.normalize + + lora_name = self.name + # need to adapt name so they are not mixed up + if self.adapter_config is not None or self.embedding is not None: + lora_name = f"{lora_name}_LoRA" + + latest_save_path = self.get_latest_save_path(lora_name) + extra_weights = None + if latest_save_path is not None: + self.print(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####") + self.print(f"Loading from {latest_save_path}") + extra_weights = self.load_weights(latest_save_path) + self.network.multiplier = 1.0 if self.embed_config is not None: # we are doing embedding training as well @@ -672,68 +704,71 @@ class BaseSDTrainProcess(BaseTrainProcess): 'lr': self.train_config.embedding_lr }) - flush() - elif self.embed_config is not None: - self.embedding = Embedding( - sd=self.sd, - embed_config=self.embed_config - ) - latest_save_path = self.get_latest_save_path(self.embed_config.trigger) - # load last saved weights - if latest_save_path is not None: - self.embedding.load_embedding_from_file(latest_save_path, self.device_torch) + flush() - # resume state from embedding - self.step_num = self.embedding.step - self.start_step = self.step_num - - params = self.get_params() - if not params: - # set trainable params - params = self.embedding.get_trainable_params() - - flush() - elif self.adapter_config is not None: - self.adapter = T2IAdapter( - in_channels=self.adapter_config.in_channels, - channels=self.adapter_config.channels, - num_res_blocks=self.adapter_config.num_res_blocks, - downscale_factor=self.adapter_config.downscale_factor, - adapter_type=self.adapter_config.adapter_type, - ) - # t2i adapter - latest_save_path = self.get_latest_save_path(self.embed_config.trigger) - if latest_save_path is not None: - # load adapter from path - print(f"Loading adapter from {latest_save_path}") - loaded_state_dict = load_t2i_model( - latest_save_path, - self.device_torch, - dtype=dtype + if self.embed_config is not None: + self.embedding = Embedding( + sd=self.sd, + embed_config=self.embed_config ) - 
self.adapter.load_state_dict(loaded_state_dict) - self.load_training_state_from_metadata(latest_save_path) - params = self.get_params() - if not params: + latest_save_path = self.get_latest_save_path(self.embed_config.trigger) + # load last saved weights + if latest_save_path is not None: + self.embedding.load_embedding_from_file(latest_save_path, self.device_torch) + + # resume state from embedding + self.step_num = self.embedding.step + self.start_step = self.step_num + + params = self.get_params() + if not params: + # set trainable params + params = self.embedding.get_trainable_params() + + flush() + + if self.adapter_config is not None: + self.adapter = T2IAdapter( + in_channels=self.adapter_config.in_channels, + channels=self.adapter_config.channels, + num_res_blocks=self.adapter_config.num_res_blocks, + downscale_factor=self.adapter_config.downscale_factor, + adapter_type=self.adapter_config.adapter_type, + ) + self.adapter.to(self.device_torch, dtype=dtype) + # t2i adapter + adapter_name = self.name + if self.network_config is not None: + adapter_name = f"{adapter_name}_t2i" + latest_save_path = self.get_latest_save_path(adapter_name) + if latest_save_path is not None: + # load adapter from path + print(f"Loading adapter from {latest_save_path}") + loaded_state_dict = load_t2i_model( + latest_save_path, + self.device, + dtype=dtype + ) + self.adapter.load_state_dict(loaded_state_dict) + self.load_training_state_from_metadata(latest_save_path) # set trainable params - params = self.adapter.parameters() - self.sd.adapter = self.adapter - flush() - else: + params.append({ + 'params': self.adapter.parameters(), + 'lr': self.train_config.adapter_lr + }) + self.sd.adapter = self.adapter + flush() + else: # no network, embedding or adapter # set the device state preset before getting params self.sd.set_device_state(self.train_device_state_preset) - - params = self.get_params() - - if params is None: - # will only return savable weights and ones with grad - params = self.sd.prepare_optimizer_params( - unet=self.train_config.train_unet, - text_encoder=self.train_config.train_text_encoder, - text_encoder_lr=self.train_config.lr, - unet_lr=self.train_config.lr, - default_lr=self.train_config.lr - ) + # will only return savable weights and ones with grad + params = self.sd.prepare_optimizer_params( + unet=self.train_config.train_unet, + text_encoder=self.train_config.train_text_encoder, + text_encoder_lr=self.train_config.lr, + unet_lr=self.train_config.lr, + default_lr=self.train_config.lr + ) flush() ### HOOK ### params = self.hook_add_extra_train_params(params) @@ -746,7 +781,7 @@ class BaseSDTrainProcess(BaseTrainProcess): self.params.append(param) optimizer_type = self.train_config.optimizer.lower() - optimizer = get_optimizer(params, optimizer_type, learning_rate=self.train_config.lr, + optimizer = get_optimizer(self.params, optimizer_type, learning_rate=self.train_config.lr, optimizer_params=self.train_config.optimizer_params) self.optimizer = optimizer diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index 4c963267..9bf76ab4 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -92,6 +92,7 @@ class TrainConfig: self.unet_lr = kwargs.get('unet_lr', self.lr) self.text_encoder_lr = kwargs.get('text_encoder_lr', self.lr) self.embedding_lr = kwargs.get('embedding_lr', self.lr) + self.adapter_lr = kwargs.get('adapter_lr', self.lr) self.optimizer = kwargs.get('optimizer', 'adamw') self.optimizer_params = kwargs.get('optimizer_params', {}) self.lr_scheduler 
= kwargs.get('lr_scheduler', 'constant') diff --git a/toolkit/saving.py b/toolkit/saving.py index 38c4a1ad..8d172b71 100644 --- a/toolkit/saving.py +++ b/toolkit/saving.py @@ -184,7 +184,7 @@ def save_t2i_from_diffusers( def load_t2i_model( path_to_file, - device: Union[str, torch.device] = 'cpu', + device: Union[str] = 'cpu', dtype: torch.dtype = torch.float32 ): raw_state_dict = load_file(path_to_file, device) diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index 3fd2310e..7cabe84c 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -250,7 +250,7 @@ class StableDiffusion: # add hacks to unet to help training # pipe.unet = prepare_unet_for_training(pipe.unet) - self.unet = pipe.unet + self.unet: 'UNet2DConditionModel' = pipe.unet self.vae: 'AutoencoderKL' = pipe.vae.to(self.device_torch, dtype=dtype) self.vae.eval() self.vae.requires_grad_(False) @@ -360,8 +360,9 @@ class StableDiffusion: extra = {} if gen_config.adapter_image_path is not None: validation_image = Image.open(gen_config.adapter_image_path).convert("RGB") - validation_image = validation_image.resize((gen_config.width, gen_config.height)) + validation_image = validation_image.resize((gen_config.width * 2, gen_config.height * 2)) extra['image'] = validation_image + extra['adapter_conditioning_scale'] = 1.0 if self.network is not None: self.network.multiplier = gen_config.network_multiplier @@ -933,7 +934,7 @@ class StableDiffusion: self.device_state['adapter'] = { 'training': self.adapter.training, 'device': self.adapter.device, - 'requires_grad': self.adapter.requires_grad, + 'requires_grad': self.adapter.adapter.conv_in.weight.requires_grad, } def restore_device_state(self):
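
Note on the min_snr_gamma branch kept in SDTrainer.hook_train_loop: the snippet below is a minimal
sketch of the Min-SNR-gamma weighting (Hang et al., 2023) that a helper like apply_snr_weight is
typically expected to compute for epsilon-prediction targets. It assumes a DDPM-style noise scheduler
exposing alphas_cumprod; the function name and variables are illustrative, not the toolkit's actual
implementation.

    import torch

    def min_snr_weight(per_sample_loss, timesteps, noise_scheduler, gamma):
        # SNR(t) = alpha_bar_t / (1 - alpha_bar_t) for a variance-preserving schedule
        alphas_cumprod = noise_scheduler.alphas_cumprod.to(per_sample_loss.device)[timesteps]
        snr = alphas_cumprod / (1.0 - alphas_cumprod)
        # clamp at gamma so low-noise (high-SNR) timesteps do not dominate the objective
        weight = torch.minimum(snr, torch.full_like(snr, gamma)) / snr
        # hook_train_loop reduces the loss over C/H/W first, so this broadcasts per sample
        return per_sample_loss * weight

For v-prediction targets the usual variant divides by (snr + 1) instead of snr.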