Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-01-26 16:39:47 +00:00
Improvements for full tuning of Flux. Added a debugging launch config for VS Code.
.vscode/launch.json (new file, vendored, 28 lines)
@@ -0,0 +1,28 @@
{
    "version": "0.2.0",
    "configurations": [
        {
            "name": "Run current config",
            "type": "python",
            "request": "launch",
            "program": "${workspaceFolder}/run.py",
            "args": [
                "${file}"
            ],
            "env": {
                "CUDA_LAUNCH_BLOCKING": "1",
                "DEBUG_TOOLKIT": "1"
            },
            "console": "integratedTerminal",
            "justMyCode": false
        },
        {
            "name": "Python: Debug Current File",
            "type": "python",
            "request": "launch",
            "program": "${file}",
            "console": "integratedTerminal",
            "justMyCode": false
        },
    ]
}
@@ -174,7 +174,21 @@ class BaseSDTrainProcess(BaseTrainProcess):
             train_adapter=is_training_adapter,
             train_embedding=self.embed_config is not None,
             train_refiner=self.train_config.train_refiner,
-            unload_text_encoder=self.train_config.unload_text_encoder
+            unload_text_encoder=self.train_config.unload_text_encoder,
+            require_grads=False  # we ensure them later
         )
 
+        self.get_params_device_state_preset = get_train_sd_device_state_preset(
+            device=self.device_torch,
+            train_unet=self.train_config.train_unet,
+            train_text_encoder=self.train_config.train_text_encoder,
+            cached_latents=self.is_latents_cached,
+            train_lora=self.network_config is not None,
+            train_adapter=is_training_adapter,
+            train_embedding=self.embed_config is not None,
+            train_refiner=self.train_config.train_refiner,
+            unload_text_encoder=self.train_config.unload_text_encoder,
+            require_grads=True  # We check for grads when getting params
+        )
+
         # fine_tuning here is for training actual SD network, not LoRA, embeddings, etc. it is (Dreambooth, etc)
@@ -575,9 +589,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
 
     def ensure_params_requires_grad(self):
         # get param groups
-        for group in self.optimizer.param_groups:
+        # for group in self.optimizer.param_groups:
+        for group in self.params:
             for param in group['params']:
-                param.requires_grad = True
+                if isinstance(param, torch.nn.Parameter):  # Ensure it's a proper parameter
+                    param.requires_grad_(True)
 
     def setup_ema(self):
         if self.train_config.ema_config.use_ema:
@@ -1487,7 +1503,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
 
         else:  # no network, embedding or adapter
             # set the device state preset before getting params
-            self.sd.set_device_state(self.train_device_state_preset)
+            self.sd.set_device_state(self.get_params_device_state_preset)
 
             # params = self.get_params()
             if len(params) == 0:
@@ -1521,6 +1537,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
             self.start_step = self.step_num
 
         optimizer_type = self.train_config.optimizer.lower()
+
+        # ensure params require grad
+        self.ensure_params_requires_grad()
         optimizer = get_optimizer(self.params, optimizer_type, learning_rate=self.train_config.lr,
                                   optimizer_params=self.train_config.optimizer_params)
         self.optimizer = optimizer
@@ -449,7 +449,13 @@ class ModelConfig:
         self.attn_masking = kwargs.get("attn_masking", False)
         if self.attn_masking and not self.is_flux:
             raise ValueError("attn_masking is only supported with flux models currently")
         pass
+        # for targeting specific layers
+        self.ignore_if_contains: Optional[List[str]] = kwargs.get("ignore_if_contains", None)
+        self.only_if_contains: Optional[List[str]] = kwargs.get("only_if_contains", None)
+
+        if self.ignore_if_contains is not None or self.only_if_contains is not None:
+            if not self.is_flux:
+                raise ValueError("ignore_if_contains and only_if_contains are only supported with flux models currently")
 
 
 class EMAConfig:
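For context, a minimal sketch of how the new layer-targeting keys might be set for a full Flux fine-tune. Only ignore_if_contains and only_if_contains come from this commit; the import path and the other keyword arguments are illustrative assumptions, and the substrings mirror the "transformer.transformer_blocks" prefix used later in this diff.

# Hypothetical sketch: restrict a full Flux fine-tune to the double-stream blocks
# and skip a couple of layers by substring match. Only ignore_if_contains /
# only_if_contains are taken from this commit; the other kwargs are assumptions.
from toolkit.config_modules import ModelConfig  # assumed module path

model_config = ModelConfig(
    name_or_path="black-forest-labs/FLUX.1-dev",  # assumed kwarg
    is_flux=True,                                 # assumed kwarg
    quantize=True,                                # assumed kwarg
    only_if_contains=["transformer.transformer_blocks."],
    ignore_if_contains=["norm_out", "proj_out"],
)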
toolkit/dequantize.py (new file, 88 lines)
@@ -0,0 +1,88 @@
from functools import partial
from optimum.quanto.tensor import QTensor
import torch


def hacked_state_dict(self, *args, **kwargs):
    orig_state_dict = self.orig_state_dict(*args, **kwargs)
    new_state_dict = {}
    for key, value in orig_state_dict.items():
        if key.endswith("._scale"):
            continue
        if key.endswith(".input_scale"):
            continue
        if key.endswith(".output_scale"):
            continue
        if key.endswith("._data"):
            key = key[:-6]
            scale = orig_state_dict[key + "._scale"]
            # scale is the original dtype
            dtype = scale.dtype
            scale = scale.float()
            value = value.float()
            dequantized = value * scale

            # handle input and output scaling if they exist
            input_scale = orig_state_dict.get(key + ".input_scale")

            if input_scale is not None:
                # make sure the tensor is 1.0
                if input_scale.item() != 1.0:
                    raise ValueError("Input scale is not 1.0, cannot dequantize")

            output_scale = orig_state_dict.get(key + ".output_scale")

            if output_scale is not None:
                # make sure the tensor is 1.0
                if output_scale.item() != 1.0:
                    raise ValueError("Output scale is not 1.0, cannot dequantize")

            new_state_dict[key] = dequantized.to('cpu', dtype=dtype)
        else:
            new_state_dict[key] = value
    return new_state_dict


# hacks the state dict so we can dequantize before saving
def patch_dequantization_on_save(model):
    model.orig_state_dict = model.state_dict
    model.state_dict = partial(hacked_state_dict, model)


def dequantize_parameter(module: torch.nn.Module, param_name: str) -> bool:
    """
    Convert a quantized parameter back to a regular Parameter with floating point values.

    Args:
        module: The module containing the parameter to unquantize
        param_name: Name of the parameter to unquantize (e.g., 'weight', 'bias')

    Returns:
        bool: True if parameter was unquantized, False if it was already unquantized
    """

    # Check if the parameter exists
    if not hasattr(module, param_name):
        raise AttributeError(f"Module has no parameter named '{param_name}'")

    param = getattr(module, param_name)

    # If it's not a parameter or not quantized, nothing to do
    if not isinstance(param, torch.nn.Parameter):
        raise TypeError(f"'{param_name}' is not a Parameter")
    if not isinstance(param, QTensor):
        return False

    # Convert to float tensor while preserving device and requires_grad
    with torch.no_grad():
        float_tensor = param.float()
        new_param = torch.nn.Parameter(
            float_tensor,
            requires_grad=param.requires_grad
        )

    # Replace the parameter
    setattr(module, param_name, new_param)

    return True
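A minimal usage sketch of the two helpers above, assuming a module quantized with optimum-quanto; the toy model and the freeze step are illustrative and not part of this commit.

# Hypothetical usage sketch for toolkit/dequantize.py; the toy model is illustrative.
import torch
from optimum.quanto import quantize, freeze, qfloat8

from toolkit.dequantize import dequantize_parameter, patch_dequantization_on_save

model = torch.nn.Sequential(torch.nn.Linear(16, 16))
quantize(model, weights=qfloat8)
freeze(model)  # frozen weights are stored as ._data / ._scale pairs in the state dict

# After patching, state_dict() returns dequantized float tensors instead of
# _data/_scale pairs, so the checkpoint keeps the original (unquantized) layout.
patch_dequantization_on_save(model)
plain_state_dict = model.state_dict()

# Convert a single layer's weight back to a plain float Parameter in place;
# returns False if the weight was not a quantized QTensor.
was_quantized = dequantize_parameter(model[0], "weight")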
@@ -41,6 +41,7 @@ def get_train_sd_device_state_preset(
         train_embedding: bool = False,
         train_refiner: bool = False,
         unload_text_encoder: bool = False,
+        require_grads: bool = True,
 ):
     preset = copy.deepcopy(empty_preset)
     if not cached_latents:
@@ -48,27 +49,27 @@ def get_train_sd_device_state_preset(
 
     if train_unet:
         preset['unet']['training'] = True
-        preset['unet']['requires_grad'] = True
+        preset['unet']['requires_grad'] = require_grads
         preset['unet']['device'] = device
     else:
         preset['unet']['device'] = device
 
     if train_text_encoder:
         preset['text_encoder']['training'] = True
-        preset['text_encoder']['requires_grad'] = True
+        preset['text_encoder']['requires_grad'] = require_grads
         preset['text_encoder']['device'] = device
     else:
         preset['text_encoder']['device'] = device
 
     if train_embedding:
         preset['text_encoder']['training'] = True
-        preset['text_encoder']['requires_grad'] = True
+        preset['text_encoder']['requires_grad'] = require_grads
         preset['text_encoder']['training'] = True
         preset['unet']['training'] = True
 
     if train_refiner:
         preset['refiner_unet']['training'] = True
-        preset['refiner_unet']['requires_grad'] = True
+        preset['refiner_unet']['requires_grad'] = require_grads
         preset['refiner_unet']['device'] = device
         # if not training unet, move that to cpu
         if not train_unet:
@@ -81,7 +82,7 @@ def get_train_sd_device_state_preset(
             preset['refiner_unet']['requires_grad'] = False
 
     if train_adapter:
-        preset['adapter']['requires_grad'] = True
+        preset['adapter']['requires_grad'] = require_grads
         preset['adapter']['training'] = True
         preset['adapter']['device'] = device
         preset['unet']['training'] = True
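Taken together with the BaseSDTrainProcess changes above, the new require_grads flag yields two presets: one with require_grads=False applied while the model is loaded and moved between devices (grads are switched on later by ensure_params_requires_grad()), and one with require_grads=True applied right before trainable parameters are collected. A minimal sketch, assuming get_train_sd_device_state_preset lives in toolkit.train_tools and that the remaining keyword arguments keep their defaults:

# Hypothetical sketch of the two presets introduced in this commit.
import torch
from toolkit.train_tools import get_train_sd_device_state_preset  # assumed module path

device = torch.device("cuda")

# used while loading / moving the model; requires_grad stays off for now
load_preset = get_train_sd_device_state_preset(
    device=device,
    train_unet=True,
    require_grads=False,
)

# used right before collecting params, so only tensors that will actually be
# trained have requires_grad set
get_params_preset = get_train_sd_device_state_preset(
    device=device,
    train_unet=True,
    require_grads=True,
)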
@@ -24,6 +24,7 @@ from torchvision.transforms import Resize, transforms
 from toolkit.assistant_lora import load_assistant_lora_from_path
 from toolkit.clip_vision_adapter import ClipVisionAdapter
 from toolkit.custom_adapter import CustomAdapter
+from toolkit.dequantize import patch_dequantization_on_save
 from toolkit.ip_adapter import IPAdapter
 from library.model_util import convert_unet_state_dict_to_sd, convert_text_encoder_state_dict_to_sd_v2, \
     convert_vae_state_dict, load_vae
@@ -660,8 +661,10 @@ class StableDiffusion:
             # unfortunately, not an easier way with peft
             pipe.unload_lora_weights()
             flush()
 
         if self.model_config.quantize:
+            # patch the state dict method
+            patch_dequantization_on_save(transformer)
             quantization_type = qfloat8
             print("Quantizing transformer")
             quantize(transformer, weights=quantization_type)
@@ -1404,6 +1407,7 @@ class StableDiffusion:
 
                 gen_config.save_image(img, i)
                 gen_config.log_image(img, i)
+                flush()
 
         if self.adapter is not None and isinstance(self.adapter, ReferenceAdapter):
             self.adapter.clear_memory()
@@ -2324,14 +2328,25 @@ class StableDiffusion:
                 # named_params[name] = param
 
                 for name, param in self.unet.transformer_blocks.named_parameters(recurse=True,
-                                                                                  prefix=f"{SD_PREFIX_UNET}"):
+                                                                                  prefix="transformer.transformer_blocks"):
                     named_params[name] = param
                 for name, param in self.unet.single_transformer_blocks.named_parameters(recurse=True,
-                                                                                         prefix=f"{SD_PREFIX_UNET}"):
+                                                                                         prefix="transformer.single_transformer_blocks"):
                     named_params[name] = param
             else:
                 for name, param in self.unet.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
                     named_params[name] = param
 
+            if self.model_config.ignore_if_contains is not None:
+                # remove params that contain the ignore_if_contains from named params
+                for key in list(named_params.keys()):
+                    if any([s in key for s in self.model_config.ignore_if_contains]):
+                        del named_params[key]
+            if self.model_config.only_if_contains is not None:
+                # remove params that do not contain the only_if_contains from named params
+                for key in list(named_params.keys()):
+                    if not any([s in key for s in self.model_config.only_if_contains]):
+                        del named_params[key]
+
             if refiner:
                 for name, param in self.refiner_unet.named_parameters(recurse=True, prefix=f"{SD_PREFIX_REFINER_UNET}"):
@@ -2420,12 +2435,6 @@ class StableDiffusion:
         # saving in diffusers format
         if not output_file.endswith('.safetensors'):
             # diffusers
-            # if self.is_pixart:
-            #     self.unet.save_pretrained(
-            #         save_directory=output_file,
-            #         safe_serialization=True,
-            #     )
-            # else:
             if self.is_flux:
                 # only save the unet
                 transformer: FluxTransformer2DModel = self.unet