From 22cd40d7b95ff9b3a01be669c0da88960d045017 Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Tue, 29 Oct 2024 04:54:08 -0600
Subject: [PATCH] Improvements for full fine-tuning of Flux. Added a debugging launch config for VS Code

---
 .vscode/launch.json                 | 28 +++++++++
 jobs/process/BaseSDTrainProcess.py  | 27 +++++++--
 toolkit/config_modules.py           |  8 ++-
 toolkit/dequantize.py               | 88 +++++++++++++++++++++++++++++
 toolkit/sd_device_states_presets.py | 11 ++--
 toolkit/stable_diffusion_model.py   | 27 ++++++---
 6 files changed, 170 insertions(+), 19 deletions(-)
 create mode 100644 .vscode/launch.json
 create mode 100644 toolkit/dequantize.py

diff --git a/.vscode/launch.json b/.vscode/launch.json
new file mode 100644
index 00000000..483703eb
--- /dev/null
+++ b/.vscode/launch.json
@@ -0,0 +1,28 @@
+{
+    "version": "0.2.0",
+    "configurations": [
+        {
+            "name": "Run current config",
+            "type": "python",
+            "request": "launch",
+            "program": "${workspaceFolder}/run.py",
+            "args": [
+                "${file}"
+            ],
+            "env": {
+                "CUDA_LAUNCH_BLOCKING": "1",
+                "DEBUG_TOOLKIT": "1"
+            },
+            "console": "integratedTerminal",
+            "justMyCode": false
+        },
+        {
+            "name": "Python: Debug Current File",
+            "type": "python",
+            "request": "launch",
+            "program": "${file}",
+            "console": "integratedTerminal",
+            "justMyCode": false
+        }
+    ]
+}
\ No newline at end of file
diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index e146cc2d..76e49b04 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -174,7 +174,21 @@ class BaseSDTrainProcess(BaseTrainProcess):
             train_adapter=is_training_adapter,
             train_embedding=self.embed_config is not None,
             train_refiner=self.train_config.train_refiner,
-            unload_text_encoder=self.train_config.unload_text_encoder
+            unload_text_encoder=self.train_config.unload_text_encoder,
+            require_grads=False  # grads are enabled later in ensure_params_requires_grad()
+        )
+
+        self.get_params_device_state_preset = get_train_sd_device_state_preset(
+            device=self.device_torch,
+            train_unet=self.train_config.train_unet,
+            train_text_encoder=self.train_config.train_text_encoder,
+            cached_latents=self.is_latents_cached,
+            train_lora=self.network_config is not None,
+            train_adapter=is_training_adapter,
+            train_embedding=self.embed_config is not None,
+            train_refiner=self.train_config.train_refiner,
+            unload_text_encoder=self.train_config.unload_text_encoder,
+            require_grads=True  # We check for grads when getting params
         )
         # fine_tuning here is for training actual SD network, not LoRA, embeddings, etc. it is (Dreambooth, etc)
 
@@ -575,9 +589,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
 
     def ensure_params_requires_grad(self):
         # get param groups
-        for group in self.optimizer.param_groups:
+        # for group in self.optimizer.param_groups:
+        for group in self.params:
             for param in group['params']:
-                param.requires_grad = True
+                if isinstance(param, torch.nn.Parameter):  # Ensure it's a proper parameter
+                    param.requires_grad_(True)
 
     def setup_ema(self):
         if self.train_config.ema_config.use_ema:
@@ -1487,7 +1503,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
         else:
             # no network, embedding or adapter
             # set the device state preset before getting params
-            self.sd.set_device_state(self.train_device_state_preset)
+            self.sd.set_device_state(self.get_params_device_state_preset)
 
             # params = self.get_params()
             if len(params) == 0:
@@ -1521,6 +1537,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
             self.start_step = self.step_num
 
         optimizer_type = self.train_config.optimizer.lower()
+
+        # ensure params require grad
+        self.ensure_params_requires_grad()
         optimizer = get_optimizer(self.params, optimizer_type, learning_rate=self.train_config.lr,
                                   optimizer_params=self.train_config.optimizer_params)
         self.optimizer = optimizer
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index 51bb57e7..a345bf43 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -449,7 +449,13 @@ class ModelConfig:
         self.attn_masking = kwargs.get("attn_masking", False)
         if self.attn_masking and not self.is_flux:
             raise ValueError("attn_masking is only supported with flux models currently")
-        pass
+        # for targeting specific layers
+        self.ignore_if_contains: Optional[List[str]] = kwargs.get("ignore_if_contains", None)
+        self.only_if_contains: Optional[List[str]] = kwargs.get("only_if_contains", None)
+
+        if self.ignore_if_contains is not None or self.only_if_contains is not None:
+            if not self.is_flux:
+                raise ValueError("ignore_if_contains and only_if_contains are only supported with flux models currently")
 
 
 class EMAConfig:
diff --git a/toolkit/dequantize.py b/toolkit/dequantize.py
new file mode 100644
index 00000000..54c8ec7b
--- /dev/null
+++ b/toolkit/dequantize.py
@@ -0,0 +1,88 @@
+
+
+from functools import partial
+from optimum.quanto.tensor import QTensor
+import torch
+
+
+def hacked_state_dict(self, *args, **kwargs):
+    orig_state_dict = self.orig_state_dict(*args, **kwargs)
+    new_state_dict = {}
+    for key, value in orig_state_dict.items():
+        if key.endswith("._scale"):
+            continue
+        if key.endswith(".input_scale"):
+            continue
+        if key.endswith(".output_scale"):
+            continue
+        if key.endswith("._data"):
+            key = key[:-6]
+            scale = orig_state_dict[key + "._scale"]
+            # the scale tensor carries the original dtype
+            dtype = scale.dtype
+            scale = scale.float()
+            value = value.float()
+            dequantized = value * scale
+
+            # handle input and output scaling if they exist
+            input_scale = orig_state_dict.get(key + ".input_scale")
+
+            if input_scale is not None:
+                # make sure the scale is 1.0
+                if input_scale.item() != 1.0:
+                    raise ValueError("Input scale is not 1.0, cannot dequantize")
+
+            output_scale = orig_state_dict.get(key + ".output_scale")
+
+            if output_scale is not None:
+                # make sure the scale is 1.0
+                if output_scale.item() != 1.0:
+                    raise ValueError("Output scale is not 1.0, cannot dequantize")
+
+            new_state_dict[key] = dequantized.to('cpu', dtype=dtype)
+        else:
+            new_state_dict[key] = value
+    return new_state_dict
+
+# hacks the state dict so we can dequantize before saving
+def patch_dequantization_on_save(model):
+    model.orig_state_dict = model.state_dict
+    model.state_dict = partial(hacked_state_dict, model)
+
+
+def dequantize_parameter(module: torch.nn.Module, param_name: str) -> bool:
+    """
+    Convert a quantized parameter back to a regular Parameter with floating point values.
+
+    Args:
+        module: The module containing the parameter to dequantize
+        param_name: Name of the parameter to dequantize (e.g., 'weight', 'bias')
+
+    Returns:
+        bool: True if the parameter was dequantized, False if it was not quantized
+    """
+
+    # Check if the parameter exists
+    if not hasattr(module, param_name):
+        raise AttributeError(f"Module has no parameter named '{param_name}'")
+
+    param = getattr(module, param_name)
+
+    # Raise if it's not a Parameter; if it's not quantized, there is nothing to do
+    if not isinstance(param, torch.nn.Parameter):
+        raise TypeError(f"'{param_name}' is not a Parameter")
+    if not isinstance(param, QTensor):
+        return False
+
+    # Convert to float tensor while preserving device and requires_grad
+    with torch.no_grad():
+        float_tensor = param.float()
+        new_param = torch.nn.Parameter(
+            float_tensor,
+            requires_grad=param.requires_grad
+        )
+
+    # Replace the parameter
+    setattr(module, param_name, new_param)
+
+    return True
\ No newline at end of file
diff --git a/toolkit/sd_device_states_presets.py b/toolkit/sd_device_states_presets.py
index 2ee1d555..0a8918ef 100644
--- a/toolkit/sd_device_states_presets.py
+++ b/toolkit/sd_device_states_presets.py
@@ -41,6 +41,7 @@ def get_train_sd_device_state_preset(
         train_embedding: bool = False,
         train_refiner: bool = False,
         unload_text_encoder: bool = False,
+        require_grads: bool = True,
 ):
     preset = copy.deepcopy(empty_preset)
     if not cached_latents:
@@ -48,27 +49,27 @@ def get_train_sd_device_state_preset(
 
     if train_unet:
         preset['unet']['training'] = True
-        preset['unet']['requires_grad'] = True
+        preset['unet']['requires_grad'] = require_grads
         preset['unet']['device'] = device
     else:
         preset['unet']['device'] = device
 
     if train_text_encoder:
         preset['text_encoder']['training'] = True
-        preset['text_encoder']['requires_grad'] = True
+        preset['text_encoder']['requires_grad'] = require_grads
         preset['text_encoder']['device'] = device
     else:
         preset['text_encoder']['device'] = device
 
     if train_embedding:
         preset['text_encoder']['training'] = True
-        preset['text_encoder']['requires_grad'] = True
+        preset['text_encoder']['requires_grad'] = require_grads
         preset['text_encoder']['training'] = True
         preset['unet']['training'] = True
 
     if train_refiner:
         preset['refiner_unet']['training'] = True
-        preset['refiner_unet']['requires_grad'] = True
+        preset['refiner_unet']['requires_grad'] = require_grads
         preset['refiner_unet']['device'] = device
         # if not training unet, move that to cpu
         if not train_unet:
@@ -81,7 +82,7 @@ def get_train_sd_device_state_preset(
             preset['refiner_unet']['requires_grad'] = False
 
     if train_adapter:
-        preset['adapter']['requires_grad'] = True
+        preset['adapter']['requires_grad'] = require_grads
         preset['adapter']['training'] = True
         preset['adapter']['device'] = device
         preset['unet']['training'] = True
diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index e85dc67e..9912d09c 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -24,6 +24,7 @@ from torchvision.transforms import Resize, transforms
 from toolkit.assistant_lora import load_assistant_lora_from_path
 from toolkit.clip_vision_adapter import ClipVisionAdapter
 from toolkit.custom_adapter import CustomAdapter
+from toolkit.dequantize import patch_dequantization_on_save
 from toolkit.ip_adapter import IPAdapter
 from library.model_util import convert_unet_state_dict_to_sd, convert_text_encoder_state_dict_to_sd_v2, \
     convert_vae_state_dict, load_vae
@@ -660,8 +661,10 @@ class StableDiffusion:
                 # unfortunately, not an easier way with peft
                 pipe.unload_lora_weights()
                 flush()
-
+
             if self.model_config.quantize:
+                # patch the state dict method
+                patch_dequantization_on_save(transformer)
                 quantization_type = qfloat8
                 print("Quantizing transformer")
                 quantize(transformer, weights=quantization_type)
@@ -1404,6 +1407,7 @@ class StableDiffusion:
                     gen_config.save_image(img, i)
                     gen_config.log_image(img, i)
 
+        flush()
 
        if self.adapter is not None and isinstance(self.adapter, ReferenceAdapter):
             self.adapter.clear_memory()
@@ -2324,14 +2328,25 @@ class StableDiffusion:
                 # named_params[name] = param
                 for name, param in self.unet.transformer_blocks.named_parameters(recurse=True,
-                                                                                 prefix=f"{SD_PREFIX_UNET}"):
+                                                                                 prefix="transformer.transformer_blocks"):
                     named_params[name] = param
                 for name, param in self.unet.single_transformer_blocks.named_parameters(recurse=True,
-                                                                                        prefix=f"{SD_PREFIX_UNET}"):
+                                                                                        prefix="transformer.single_transformer_blocks"):
                     named_params[name] = param
             else:
                 for name, param in self.unet.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
                     named_params[name] = param
+
+            if self.model_config.ignore_if_contains is not None:
+                # remove params whose names contain any of the ignore_if_contains strings
+                for key in list(named_params.keys()):
+                    if any([s in key for s in self.model_config.ignore_if_contains]):
+                        del named_params[key]
+            if self.model_config.only_if_contains is not None:
+                # remove params whose names do not contain any of the only_if_contains strings
+                for key in list(named_params.keys()):
+                    if not any([s in key for s in self.model_config.only_if_contains]):
+                        del named_params[key]
         if refiner:
             for name, param in self.refiner_unet.named_parameters(recurse=True,
                                                                   prefix=f"{SD_PREFIX_REFINER_UNET}"):
                 named_params[name] = param
@@ -2420,12 +2435,6 @@ class StableDiffusion:
         # saving in diffusers format
         if not output_file.endswith('.safetensors'):
             # diffusers
-            # if self.is_pixart:
-            #     self.unet.save_pretrained(
-            #         save_directory=output_file,
-            #         safe_serialization=True,
-            #     )
-            # else:
             if self.is_flux:
                 # only save the unet
                 transformer: FluxTransformer2DModel = self.unet
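
Usage note for review: patch_dequantization_on_save() monkey-patches model.state_dict so that quanto's "._data"/"._scale" tensor pairs are multiplied back into plain float weights and the quanto bookkeeping keys are dropped at save time, while dequantize_parameter() converts a single quantized parameter back to a regular nn.Parameter in place. The sketch below shows one way the two helpers can be wired together; the function names and the loop are illustrative only and not part of this patch.

import torch
from toolkit.dequantize import patch_dequantization_on_save, dequantize_parameter

def export_plain_state_dict(transformer: torch.nn.Module) -> dict:
    # After patching, state_dict() folds each "._data" tensor into its "._scale"
    # (value * scale) and returns plain float tensors keyed by the normal
    # parameter names, so the checkpoint loads without optimum-quanto installed.
    patch_dequantization_on_save(transformer)
    return transformer.state_dict()

def dequantize_weights_in_place(transformer: torch.nn.Module) -> int:
    # Convert quantized weights back to regular float Parameters one by one;
    # dequantize_parameter() returns False for weights that were never quantized.
    count = 0
    for module in transformer.modules():
        weight = getattr(module, "weight", None)
        if isinstance(weight, torch.nn.Parameter) and dequantize_parameter(module, "weight"):
            count += 1
    return count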
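
Usage note for review: the new ModelConfig options ignore_if_contains and only_if_contains are plain substring filters applied in the named-parameter collection above, after the Flux transformer parameters are gathered under the "transformer.transformer_blocks" / "transformer.single_transformer_blocks" prefixes. A self-contained sketch of the filtering behavior follows; the parameter names and placeholder values are made up for the example.

# made-up parameter names; real values would be torch.nn.Parameter objects
named_params = {
    "transformer.transformer_blocks.0.attn.to_q.weight": None,
    "transformer.transformer_blocks.0.ff.net.0.proj.weight": None,
    "transformer.single_transformer_blocks.5.proj_mlp.weight": None,
}

ignore_if_contains = ["ff.net"]             # drop names containing any of these
only_if_contains = ["transformer_blocks"]   # then keep only names containing any of these

for key in list(named_params.keys()):
    if any(s in key for s in ignore_if_contains):
        del named_params[key]
for key in list(named_params.keys()):
    if not any(s in key for s in only_if_contains):
        del named_params[key]

print(list(named_params.keys()))
# -> ['transformer.transformer_blocks.0.attn.to_q.weight',
#     'transformer.single_transformer_blocks.5.proj_mlp.weight']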