Improvements for full tuning of Flux. Added a debugging launch config for VS Code.
@@ -24,6 +24,7 @@ from torchvision.transforms import Resize, transforms
 from toolkit.assistant_lora import load_assistant_lora_from_path
 from toolkit.clip_vision_adapter import ClipVisionAdapter
 from toolkit.custom_adapter import CustomAdapter
+from toolkit.dequantize import patch_dequantization_on_save
 from toolkit.ip_adapter import IPAdapter
 from library.model_util import convert_unet_state_dict_to_sd, convert_text_encoder_state_dict_to_sd_v2, \
     convert_vae_state_dict, load_vae
@@ -660,8 +661,10 @@ class StableDiffusion:
             # unfortunately, not an easier way with peft
             pipe.unload_lora_weights()
             flush()
 
         if self.model_config.quantize:
+            # patch the state dict method
+            patch_dequantization_on_save(transformer)
             quantization_type = qfloat8
             print("Quantizing transformer")
             quantize(transformer, weights=quantization_type)
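The two added lines prepare the transformer for float8 quantization: judging by its name and the inline comment, patch_dequantization_on_save presumably wraps the model's state-dict method so saved weights come out dequantized, while the existing quantize(..., weights=qfloat8) call matches optimum-quanto's API. A minimal sketch of that weight-only quantization flow, assuming optimum-quanto is the library behind quantize and qfloat8 (the toy model is a hypothetical stand-in for the Flux transformer):

# Minimal sketch of weight-only float8 quantization with optimum-quanto;
# the toy Sequential model stands in for the Flux transformer.
import torch
from optimum.quanto import quantize, freeze, qfloat8

model = torch.nn.Sequential(
    torch.nn.Linear(64, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 8),
)

quantize(model, weights=qfloat8)  # swap Linear weights for qfloat8 versions
freeze(model)                     # materialize quantized weights, free fp copies

print(model(torch.randn(1, 64)).shape)  # torch.Size([1, 8])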
@@ -1404,6 +1407,7 @@ class StableDiffusion:
 
                 gen_config.save_image(img, i)
                 gen_config.log_image(img, i)
+                flush()
 
         if self.adapter is not None and isinstance(self.adapter, ReferenceAdapter):
             self.adapter.clear_memory()
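Calling flush() after each sampled image keeps VRAM from accumulating across the sampling loop, which matters most when full-tuning a model as large as Flux. The helper's body isn't shown in this diff; a hedged sketch of what a flush() of this kind typically does (an assumption, not necessarily ai-toolkit's exact code):

# Hedged sketch of a flush() helper: drop Python references, then release
# cached CUDA blocks between generations. The real helper may differ.
import gc
import torch

def flush():
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()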
@@ -2324,14 +2328,25 @@ class StableDiffusion:
                 # named_params[name] = param
 
             for name, param in self.unet.transformer_blocks.named_parameters(recurse=True,
-                                                                             prefix=f"{SD_PREFIX_UNET}"):
+                                                                             prefix="transformer.transformer_blocks"):
                 named_params[name] = param
             for name, param in self.unet.single_transformer_blocks.named_parameters(recurse=True,
-                                                                                    prefix=f"{SD_PREFIX_UNET}"):
+                                                                                    prefix="transformer.single_transformer_blocks"):
                 named_params[name] = param
         else:
             for name, param in self.unet.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
                 named_params[name] = param
 
+        if self.model_config.ignore_if_contains is not None:
+            # remove params that contain the ignore_if_contains from named params
+            for key in list(named_params.keys()):
+                if any([s in key for s in self.model_config.ignore_if_contains]):
+                    del named_params[key]
+        if self.model_config.only_if_contains is not None:
+            # remove params that do not contain the only_if_contains from named params
+            for key in list(named_params.keys()):
+                if not any([s in key for s in self.model_config.only_if_contains]):
+                    del named_params[key]
+
         if refiner:
             for name, param in self.refiner_unet.named_parameters(recurse=True, prefix=f"{SD_PREFIX_REFINER_UNET}"):
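The prefix change matters for full tuning: keys collected as "transformer.transformer_blocks..." line up with the diffusers FluxTransformer2DModel state-dict layout rather than the generic SD_PREFIX_UNET prefix, and the new ignore_if_contains / only_if_contains filters let a config train only a subset of those parameters. A self-contained illustration of the pattern, using PyTorch's Module.named_parameters(prefix=...) with a hypothetical toy module and filter list:

# Toy demonstration of named_parameters(prefix=...) plus substring filtering;
# the module layout and the ignore list are hypothetical examples.
import torch

class Block(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = torch.nn.Linear(8, 8)
        self.norm = torch.nn.LayerNorm(8)

blocks = torch.nn.ModuleList([Block()])

named_params = {}
# prefix= prepends a dotted path so keys match the checkpoint layout.
for name, param in blocks.named_parameters(recurse=True,
                                           prefix="transformer.transformer_blocks"):
    named_params[name] = param

ignore_if_contains = ["norm"]  # drop any param whose key contains "norm"
for key in list(named_params.keys()):
    if any(s in key for s in ignore_if_contains):
        del named_params[key]

print(list(named_params.keys()))
# ['transformer.transformer_blocks.0.attn.weight',
#  'transformer.transformer_blocks.0.attn.bias']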
@@ -2420,12 +2435,6 @@ class StableDiffusion:
         # saving in diffusers format
         if not output_file.endswith('.safetensors'):
             # diffusers
-            # if self.is_pixart:
-            #     self.unet.save_pretrained(
-            #         save_directory=output_file,
-            #         safe_serialization=True,
-            #     )
-            # else:
             if self.is_flux:
                 # only save the unet
                 transformer: FluxTransformer2DModel = self.unet
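The deleted lines were a dead, commented-out PixArt save path; the live Flux branch (truncated here) saves only the transformer in diffusers format. For reference, a hedged sketch of how a diffusers model such as FluxTransformer2DModel is written to a directory, assuming a recent diffusers release:

# Hedged sketch: saving a diffusers transformer in diffusers format.
# save_pretrained writes config.json plus .safetensors weight shards.
from diffusers import FluxTransformer2DModel

def save_transformer(transformer: FluxTransformer2DModel, output_dir: str) -> None:
    transformer.save_pretrained(output_dir, safe_serialization=True)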