From 4aa19b5c1ddf30aad233e907fdd45447878a2231 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Tue, 29 Oct 2024 14:25:31 -0600 Subject: [PATCH] Only quantize flux T5 if also quantizing model. Load TE from original name and path if fine tuning. --- toolkit/config_modules.py | 2 ++ toolkit/stable_diffusion_model.py | 43 +++++++------------------ 2 files changed, 11 insertions(+), 34 deletions(-) diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index a345bf43..5eb1a37c 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -394,6 +394,8 @@ class TrainConfig: class ModelConfig: def __init__(self, **kwargs): self.name_or_path: str = kwargs.get('name_or_path', None) + # name or path is updated on fine tuning. Keep a copy of the original + self.name_or_path_original: str = self.name_or_path self.is_v2: bool = kwargs.get('is_v2', False) self.is_xl: bool = kwargs.get('is_xl', False) self.is_pixart: bool = kwargs.get('is_pixart', False) diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index 9912d09c..db19b365 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -527,7 +527,8 @@ class StableDiffusion: elif self.model_config.is_flux: print("Loading Flux model") - base_model_path = "black-forest-labs/FLUX.1-schnell" + # base_model_path = "black-forest-labs/FLUX.1-schnell" + base_model_path = self.model_config.name_or_path_original print("Loading transformer") subfolder = 'transformer' transformer_path = model_path @@ -688,11 +689,12 @@ class StableDiffusion: text_encoder_2.to(self.device_torch, dtype=dtype) flush() - print("Quantizing T5") - quantize(text_encoder_2, weights=qfloat8) - freeze(text_encoder_2) - flush() - + if self.model_config.quantize: + print("Quantizing T5") + quantize(text_encoder_2, weights=qfloat8) + freeze(text_encoder_2) + flush() + print("Loading clip") text_encoder = CLIPTextModel.from_pretrained(base_model_path, subfolder="text_encoder",
torch_dtype=dtype) tokenizer = CLIPTokenizer.from_pretrained(base_model_path, subfolder="tokenizer", torch_dtype=dtype) @@ -2304,34 +2306,7 @@ class StableDiffusion: named_params[name] = param if unet: if self.is_flux: - # Just train the middle 2 blocks of each transformer block - # block_list = [] - # num_transformer_blocks = 2 - # start_block = len(self.unet.transformer_blocks) // 2 - (num_transformer_blocks // 2) - # for i in range(num_transformer_blocks): - # block_list.append(self.unet.transformer_blocks[start_block + i]) - # - # num_single_transformer_blocks = 4 - # start_block = len(self.unet.single_transformer_blocks) // 2 - (num_single_transformer_blocks // 2) - # for i in range(num_single_transformer_blocks): - # block_list.append(self.unet.single_transformer_blocks[start_block + i]) - # - # for block in block_list: - # for name, param in block.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"): - # named_params[name] = param - - # train the guidance embedding - # if self.unet.config.guidance_embeds: - # transformer: FluxTransformer2DModel = self.unet - # for name, param in transformer.time_text_embed.named_parameters(recurse=True, - # prefix=f"{SD_PREFIX_UNET}"): - # named_params[name] = param - - for name, param in self.unet.transformer_blocks.named_parameters(recurse=True, - prefix="transformer.transformer_blocks"): - named_params[name] = param - for name, param in self.unet.single_transformer_blocks.named_parameters(recurse=True, - prefix="transformer.single_transformer_blocks"): + for name, param in self.unet.named_parameters(recurse=True, prefix="transformer"): named_params[name] = param else: for name, param in self.unet.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):