Improvements for full tuning flux. Added debugging launch config for vscode

This commit is contained in:
Jaret Burkett
2024-10-29 04:54:08 -06:00
parent 3400882a80
commit 22cd40d7b9
6 changed files with 170 additions and 19 deletions

View File

@@ -24,6 +24,7 @@ from torchvision.transforms import Resize, transforms
from toolkit.assistant_lora import load_assistant_lora_from_path
from toolkit.clip_vision_adapter import ClipVisionAdapter
from toolkit.custom_adapter import CustomAdapter
from toolkit.dequantize import patch_dequantization_on_save
from toolkit.ip_adapter import IPAdapter
from library.model_util import convert_unet_state_dict_to_sd, convert_text_encoder_state_dict_to_sd_v2, \
convert_vae_state_dict, load_vae
@@ -660,8 +661,10 @@ class StableDiffusion:
# unfortunately, not an easier way with peft
pipe.unload_lora_weights()
flush()
if self.model_config.quantize:
# patch the state dict method
patch_dequantization_on_save(transformer)
quantization_type = qfloat8
print("Quantizing transformer")
quantize(transformer, weights=quantization_type)
@@ -1404,6 +1407,7 @@ class StableDiffusion:
gen_config.save_image(img, i)
gen_config.log_image(img, i)
flush()
if self.adapter is not None and isinstance(self.adapter, ReferenceAdapter):
self.adapter.clear_memory()
@@ -2324,14 +2328,25 @@ class StableDiffusion:
# named_params[name] = param
for name, param in self.unet.transformer_blocks.named_parameters(recurse=True,
prefix=f"{SD_PREFIX_UNET}"):
prefix="transformer.transformer_blocks"):
named_params[name] = param
for name, param in self.unet.single_transformer_blocks.named_parameters(recurse=True,
prefix=f"{SD_PREFIX_UNET}"):
prefix="transformer.single_transformer_blocks"):
named_params[name] = param
else:
for name, param in self.unet.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
named_params[name] = param
if self.model_config.ignore_if_contains is not None:
# remove params that contain the ignore_if_contains from named params
for key in list(named_params.keys()):
if any([s in key for s in self.model_config.ignore_if_contains]):
del named_params[key]
if self.model_config.only_if_contains is not None:
# remove params that do not contain the only_if_contains from named params
for key in list(named_params.keys()):
if not any([s in key for s in self.model_config.only_if_contains]):
del named_params[key]
if refiner:
for name, param in self.refiner_unet.named_parameters(recurse=True, prefix=f"{SD_PREFIX_REFINER_UNET}"):
@@ -2420,12 +2435,6 @@ class StableDiffusion:
# saving in diffusers format
if not output_file.endswith('.safetensors'):
# diffusers
# if self.is_pixart:
# self.unet.save_pretrained(
# save_directory=output_file,
# safe_serialization=True,
# )
# else:
if self.is_flux:
# only save the unet
transformer: FluxTransformer2DModel = self.unet