Improvements for full fine-tuning of Flux. Added a debugging launch config for VS Code

Jaret Burkett
2024-10-29 04:54:08 -06:00
parent 3400882a80
commit 22cd40d7b9
6 changed files with 170 additions and 19 deletions

.vscode/launch.json (new file, +28)

@@ -0,0 +1,28 @@
{
"version": "0.2.0",
"configurations": [
{
"name": "Run current config",
"type": "python",
"request": "launch",
"program": "${workspaceFolder}/run.py",
"args": [
"${file}"
],
"env": {
"CUDA_LAUNCH_BLOCKING": "1",
"DEBUG_TOOLKIT": "1"
},
"console": "integratedTerminal",
"justMyCode": false
},
{
"name": "Python: Debug Current File",
"type": "python",
"request": "launch",
"program": "${file}",
"console": "integratedTerminal",
"justMyCode": false
}
]
}
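The "Run current config" entry launches run.py with the currently open file as its config argument. CUDA_LAUNCH_BLOCKING=1 is the standard CUDA switch that makes kernel launches synchronous so errors surface at the real failing op; DEBUG_TOOLKIT=1 is assumed here to put the toolkit into a more verbose failure mode. A minimal sketch of how such a flag could be read (the helper below is hypothetical, not part of the toolkit):

import os

# Hypothetical debug gate: DEBUG_TOOLKIT=1 is assumed to enable extra diagnostics.
DEBUG_TOOLKIT = os.environ.get("DEBUG_TOOLKIT", "0") == "1"

def maybe_reraise(exc: Exception) -> None:
    # In debug mode, re-raise so the debugger stops on the original exception.
    if DEBUG_TOOLKIT:
        raise exc
    print(f"suppressed error: {exc}")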

@@ -174,7 +174,21 @@ class BaseSDTrainProcess(BaseTrainProcess):
train_adapter=is_training_adapter,
train_embedding=self.embed_config is not None,
train_refiner=self.train_config.train_refiner,
unload_text_encoder=self.train_config.unload_text_encoder
unload_text_encoder=self.train_config.unload_text_encoder,
require_grads=False # grads are enabled later, in ensure_params_requires_grad()
)
self.get_params_device_state_preset = get_train_sd_device_state_preset(
device=self.device_torch,
train_unet=self.train_config.train_unet,
train_text_encoder=self.train_config.train_text_encoder,
cached_latents=self.is_latents_cached,
train_lora=self.network_config is not None,
train_adapter=is_training_adapter,
train_embedding=self.embed_config is not None,
train_refiner=self.train_config.train_refiner,
unload_text_encoder=self.train_config.unload_text_encoder,
require_grads=True # We check for grads when getting params
)
# fine_tuning here means training the actual SD network (Dreambooth, etc.), not LoRA, embeddings, etc.
@@ -575,9 +589,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
def ensure_params_requires_grad(self):
# get param groups
for group in self.optimizer.param_groups:
# for group in self.optimizer.param_groups:
for group in self.params:
for param in group['params']:
param.requires_grad = True
if isinstance(param, torch.nn.Parameter): # Ensure it's a proper parameter
param.requires_grad_(True)
def setup_ema(self):
if self.train_config.ema_config.use_ema:
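ensure_params_requires_grad now walks the param groups built for the optimizer (self.params) instead of self.optimizer.param_groups, since, as a later hunk shows, it is called before the optimizer exists, and it only flips requires_grad on real torch.nn.Parameter objects. A standalone sketch of the same pattern, assuming optimizer-style groups of the form {"params": [...]}:

import torch

param_groups = [
    {"params": [torch.nn.Parameter(torch.zeros(4))], "lr": 1e-4},
    {"params": [torch.zeros(4)], "lr": 1e-4},  # a plain tensor would be skipped
]

for group in param_groups:
    for p in group["params"]:
        if isinstance(p, torch.nn.Parameter):  # only touch proper Parameters
            p.requires_grad_(True)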
@@ -1487,7 +1503,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
else: # no network, embedding or adapter
# set the device state preset before getting params
self.sd.set_device_state(self.train_device_state_preset)
self.sd.set_device_state(self.get_params_device_state_preset)
# params = self.get_params()
if len(params) == 0:
@@ -1521,6 +1537,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.start_step = self.step_num
optimizer_type = self.train_config.optimizer.lower()
# ensure params require grad
self.ensure_params_requires_grad()
optimizer = get_optimizer(self.params, optimizer_type, learning_rate=self.train_config.lr,
optimizer_params=self.train_config.optimizer_params)
self.optimizer = optimizer

@@ -449,7 +449,13 @@ class ModelConfig:
self.attn_masking = kwargs.get("attn_masking", False)
if self.attn_masking and not self.is_flux:
raise ValueError("attn_masking is only supported with flux models currently")
pass
# for targeting specific layers
self.ignore_if_contains: Optional[List[str]] = kwargs.get("ignore_if_contains", None)
self.only_if_contains: Optional[List[str]] = kwargs.get("only_if_contains", None)
if self.ignore_if_contains is not None or self.only_if_contains is not None:
if not self.is_flux:
raise ValueError("ignore_if_contains and only_if_contains are only supported with flux models currently")
class EMAConfig:
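ignore_if_contains and only_if_contains act as substring filters on parameter names when trainable params are collected (see the stable_diffusion hunk further down). A toy sketch of the filter semantics, with illustrative names rather than the toolkit's real ones:

named_params = {
    "transformer.transformer_blocks.0.attn.to_q.weight": None,
    "transformer.single_transformer_blocks.3.proj_mlp.weight": None,
}
ignore_if_contains = ["single_transformer_blocks"]
only_if_contains = None

if ignore_if_contains is not None:
    named_params = {k: v for k, v in named_params.items()
                    if not any(s in k for s in ignore_if_contains)}
if only_if_contains is not None:
    named_params = {k: v for k, v in named_params.items()
                    if any(s in k for s in only_if_contains)}

print(list(named_params))  # only the transformer_blocks entry remains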

toolkit/dequantize.py (new file, +88)

@@ -0,0 +1,88 @@
from functools import partial
from optimum.quanto.tensor import QTensor
import torch
def hacked_state_dict(self, *args, **kwargs):
orig_state_dict = self.orig_state_dict(*args, **kwargs)
new_state_dict = {}
for key, value in orig_state_dict.items():
if key.endswith("._scale"):
continue
if key.endswith(".input_scale"):
continue
if key.endswith(".output_scale"):
continue
if key.endswith("._data"):
key = key[:-6]
scale = orig_state_dict[key + "._scale"]
# the scale is stored in the original (unquantized) dtype
dtype = scale.dtype
scale = scale.float()
value = value.float()
dequantized = value * scale
# handle input and output scaling if they exist
input_scale = orig_state_dict.get(key + ".input_scale")
if input_scale is not None:
# make sure the input scale is 1.0
if input_scale.item() != 1.0:
raise ValueError("Input scale is not 1.0, cannot dequantize")
output_scale = orig_state_dict.get(key + ".output_scale")
if output_scale is not None:
# make sure the output scale is 1.0
if output_scale.item() != 1.0:
raise ValueError("Output scale is not 1.0, cannot dequantize")
new_state_dict[key] = dequantized.to('cpu', dtype=dtype)
else:
new_state_dict[key] = value
return new_state_dict
# hacks the state dict so we can dequantize before saving
def patch_dequantization_on_save(model):
model.orig_state_dict = model.state_dict
model.state_dict = partial(hacked_state_dict, model)
def dequantize_parameter(module: torch.nn.Module, param_name: str) -> bool:
"""
Convert a quantized parameter back to a regular Parameter with floating point values.
Args:
module: The module containing the parameter to dequantize
param_name: Name of the parameter to dequantize (e.g., 'weight', 'bias')
Returns:
bool: True if the parameter was dequantized, False if it was not quantized
"""
# Check if the parameter exists
if not hasattr(module, param_name):
raise AttributeError(f"Module has no parameter named '{param_name}'")
param = getattr(module, param_name)
# If it's not a parameter or not quantized, nothing to do
if not isinstance(param, torch.nn.Parameter):
raise TypeError(f"'{param_name}' is not a Parameter")
if not isinstance(param, QTensor):
return False
# Convert to float tensor while preserving device and requires_grad
with torch.no_grad():
float_tensor = param.float()
new_param = torch.nn.Parameter(
float_tensor,
requires_grad=param.requires_grad
)
# Replace the parameter
setattr(module, param_name, new_param)
return True
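A minimal usage sketch for the two helpers, assuming optimum-quanto's quantize/freeze API and that quantized modules serialize their weights as <name>._data / <name>._scale pairs (which is exactly what hacked_state_dict folds back together):

import torch
from optimum.quanto import quantize, freeze, qfloat8

from toolkit.dequantize import dequantize_parameter, patch_dequantization_on_save

model = torch.nn.Sequential(torch.nn.Linear(16, 16))
quantize(model, weights=qfloat8)
freeze(model)

# After patching, state_dict() returns plain float tensors instead of quanto's
# _data/_scale pairs, so a normal safetensors save works unchanged.
patch_dequantization_on_save(model)
state = model.state_dict()
print({k: v.dtype for k, v in state.items()})

# dequantize_parameter swaps a single quantized weight back to a float Parameter in place.
dequantize_parameter(model[0], "weight")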

@@ -41,6 +41,7 @@ def get_train_sd_device_state_preset(
train_embedding: bool = False,
train_refiner: bool = False,
unload_text_encoder: bool = False,
require_grads: bool = True,
):
preset = copy.deepcopy(empty_preset)
if not cached_latents:
@@ -48,27 +49,27 @@ def get_train_sd_device_state_preset(
if train_unet:
preset['unet']['training'] = True
preset['unet']['requires_grad'] = True
preset['unet']['requires_grad'] = require_grads
preset['unet']['device'] = device
else:
preset['unet']['device'] = device
if train_text_encoder:
preset['text_encoder']['training'] = True
preset['text_encoder']['requires_grad'] = True
preset['text_encoder']['requires_grad'] = require_grads
preset['text_encoder']['device'] = device
else:
preset['text_encoder']['device'] = device
if train_embedding:
preset['text_encoder']['training'] = True
preset['text_encoder']['requires_grad'] = True
preset['text_encoder']['requires_grad'] = require_grads
preset['text_encoder']['training'] = True
preset['unet']['training'] = True
if train_refiner:
preset['refiner_unet']['training'] = True
preset['refiner_unet']['requires_grad'] = True
preset['refiner_unet']['requires_grad'] = require_grads
preset['refiner_unet']['device'] = device
# if not training the unet, move it to cpu
if not train_unet:
@@ -81,7 +82,7 @@ def get_train_sd_device_state_preset(
preset['refiner_unet']['requires_grad'] = False
if train_adapter:
preset['adapter']['requires_grad'] = True
preset['adapter']['requires_grad'] = require_grads
preset['adapter']['training'] = True
preset['adapter']['device'] = device
preset['unet']['training'] = True
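With require_grads=False the setup preset only controls device placement and training mode, while the param-gathering preset passes True and keeps the old behaviour of also enabling grads. A rough sketch of how one preset entry might be consumed (the apply function is hypothetical; the toolkit's set_device_state is assumed to do something along these lines):

import torch

def apply_device_state(module: torch.nn.Module, state: dict) -> None:
    # Hypothetical consumer of a single preset entry.
    module.to(state["device"])
    module.train(state["training"])
    module.requires_grad_(state["requires_grad"])

require_grads = False  # setup preset; the param-gathering preset passes True
unet = torch.nn.Linear(8, 8)
apply_device_state(unet, {"training": True, "requires_grad": require_grads, "device": "cpu"})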

@@ -24,6 +24,7 @@ from torchvision.transforms import Resize, transforms
from toolkit.assistant_lora import load_assistant_lora_from_path
from toolkit.clip_vision_adapter import ClipVisionAdapter
from toolkit.custom_adapter import CustomAdapter
from toolkit.dequantize import patch_dequantization_on_save
from toolkit.ip_adapter import IPAdapter
from library.model_util import convert_unet_state_dict_to_sd, convert_text_encoder_state_dict_to_sd_v2, \
convert_vae_state_dict, load_vae
@@ -660,8 +661,10 @@ class StableDiffusion:
# unfortunately, there is no easier way with peft
pipe.unload_lora_weights()
flush()
if self.model_config.quantize:
# patch the state dict method
patch_dequantization_on_save(transformer)
quantization_type = qfloat8
print("Quantizing transformer")
quantize(transformer, weights=quantization_type)
@@ -1404,6 +1407,7 @@ class StableDiffusion:
gen_config.save_image(img, i)
gen_config.log_image(img, i)
flush()
if self.adapter is not None and isinstance(self.adapter, ReferenceAdapter):
self.adapter.clear_memory()
@@ -2324,14 +2328,25 @@ class StableDiffusion:
# named_params[name] = param
for name, param in self.unet.transformer_blocks.named_parameters(recurse=True,
prefix=f"{SD_PREFIX_UNET}"):
prefix="transformer.transformer_blocks"):
named_params[name] = param
for name, param in self.unet.single_transformer_blocks.named_parameters(recurse=True,
prefix=f"{SD_PREFIX_UNET}"):
prefix="transformer.single_transformer_blocks"):
named_params[name] = param
else:
for name, param in self.unet.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
named_params[name] = param
if self.model_config.ignore_if_contains is not None:
# drop params whose name contains any ignore_if_contains substring
for key in list(named_params.keys()):
if any([s in key for s in self.model_config.ignore_if_contains]):
del named_params[key]
if self.model_config.only_if_contains is not None:
# drop params whose name does not contain any only_if_contains substring
for key in list(named_params.keys()):
if not any([s in key for s in self.model_config.only_if_contains]):
del named_params[key]
if refiner:
for name, param in self.refiner_unet.named_parameters(recurse=True, prefix=f"{SD_PREFIX_REFINER_UNET}"):
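The prefix passed to named_parameters is prepended verbatim to every yielded name, so switching from the generic SD_PREFIX_UNET to "transformer.transformer_blocks" / "transformer.single_transformer_blocks" presumably makes the collected names match the Flux transformer's real module paths, which is what the substring filters above operate on. A tiny demo of the prefix behaviour:

import torch

block = torch.nn.Linear(4, 4)
for name, _ in block.named_parameters(recurse=True, prefix="transformer.transformer_blocks"):
    print(name)  # transformer.transformer_blocks.weight, transformer.transformer_blocks.bias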
@@ -2420,12 +2435,6 @@ class StableDiffusion:
# saving in diffusers format
if not output_file.endswith('.safetensors'):
# diffusers
# if self.is_pixart:
# self.unet.save_pretrained(
# save_directory=output_file,
# safe_serialization=True,
# )
# else:
if self.is_flux:
# only save the unet
transformer: FluxTransformer2DModel = self.unet