Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-03-11 13:39:50 +00:00
Only quantize the Flux T5 text encoder when the model itself is also being quantized. Load the text encoder from the original name and path when fine tuning.
@@ -394,6 +394,8 @@ class TrainConfig:
 class ModelConfig:
     def __init__(self, **kwargs):
         self.name_or_path: str = kwargs.get('name_or_path', None)
+        # name or path is updated on fine tuning. Keep a copy of the original
+        self.name_or_path_original: str = self.name_or_path
         self.is_v2: bool = kwargs.get('is_v2', False)
         self.is_xl: bool = kwargs.get('is_xl', False)
         self.is_pixart: bool = kwargs.get('is_pixart', False)
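A minimal sketch of how the new field is meant to behave (hypothetical usage; the import path is an assumption, not confirmed by the diff):

    # Hypothetical usage: name_or_path may later be repointed at a fine-tuned
    # checkpoint, while name_or_path_original keeps the base repo id so other
    # components (tokenizers, text encoders) can still be loaded from it.
    from toolkit.config_modules import ModelConfig  # assumed module path

    config = ModelConfig(name_or_path="black-forest-labs/FLUX.1-schnell")
    config.name_or_path = "/tmp/finetuned-flux"  # hypothetical fine-tuned checkpoint
    assert config.name_or_path_original == "black-forest-labs/FLUX.1-schnell"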
@@ -527,7 +527,8 @@ class StableDiffusion:
 
         elif self.model_config.is_flux:
             print("Loading Flux model")
-            base_model_path = "black-forest-labs/FLUX.1-schnell"
+            # base_model_path = "black-forest-labs/FLUX.1-schnell"
+            base_model_path = self.model_config.name_or_path_original
             print("Loading transformer")
             subfolder = 'transformer'
             transformer_path = model_path
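A minimal sketch of the resulting load split, using assumed names that follow the diff (the loader call is illustrative, not lifted from the repo): the transformer can come from a fine-tuned checkpoint while base_model_path keeps pointing at the original repo, which the text encoders and tokenizers below are loaded from.

    from diffusers import FluxTransformer2DModel

    base_model_path = "black-forest-labs/FLUX.1-schnell"   # model_config.name_or_path_original
    transformer_path = "/tmp/finetuned-flux"               # hypothetical fine-tuned checkpoint
    transformer = FluxTransformer2DModel.from_pretrained(transformer_path, subfolder="transformer")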
@@ -688,11 +689,12 @@ class StableDiffusion:
             text_encoder_2.to(self.device_torch, dtype=dtype)
             flush()
 
-            print("Quantizing T5")
-            quantize(text_encoder_2, weights=qfloat8)
-            freeze(text_encoder_2)
-            flush()
+            if self.model_config.quantize:
+                print("Quantizing T5")
+                quantize(text_encoder_2, weights=qfloat8)
+                freeze(text_encoder_2)
+                flush()
 
             print("Loading clip")
             text_encoder = CLIPTextModel.from_pretrained(base_model_path, subfolder="text_encoder", torch_dtype=dtype)
             tokenizer = CLIPTokenizer.from_pretrained(base_model_path, subfolder="tokenizer", torch_dtype=dtype)
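A self-contained sketch of the gated quantization, assuming the quantize/freeze/qfloat8 calls in the diff come from optimum-quanto (the signatures match that library, but the import is an assumption):

    from optimum.quanto import freeze, qfloat8, quantize
    from transformers import T5EncoderModel

    def maybe_quantize_t5(text_encoder_2: T5EncoderModel, do_quantize: bool) -> None:
        # Only quantize the T5 encoder when the model is also being quantized.
        if do_quantize:
            quantize(text_encoder_2, weights=qfloat8)  # swap linear weights to qfloat8
            freeze(text_encoder_2)                     # make the quantized weights permanent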
@@ -2304,34 +2306,7 @@ class StableDiffusion:
                 named_params[name] = param
         if unet:
             if self.is_flux:
-                # Just train the middle 2 blocks of each transformer block
-                # block_list = []
-                # num_transformer_blocks = 2
-                # start_block = len(self.unet.transformer_blocks) // 2 - (num_transformer_blocks // 2)
-                # for i in range(num_transformer_blocks):
-                #     block_list.append(self.unet.transformer_blocks[start_block + i])
-                #
-                # num_single_transformer_blocks = 4
-                # start_block = len(self.unet.single_transformer_blocks) // 2 - (num_single_transformer_blocks // 2)
-                # for i in range(num_single_transformer_blocks):
-                #     block_list.append(self.unet.single_transformer_blocks[start_block + i])
-                #
-                # for block in block_list:
-                #     for name, param in block.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
-                #         named_params[name] = param
-
-                # train the guidance embedding
-                # if self.unet.config.guidance_embeds:
-                #     transformer: FluxTransformer2DModel = self.unet
-                #     for name, param in transformer.time_text_embed.named_parameters(recurse=True,
-                #                                                                     prefix=f"{SD_PREFIX_UNET}"):
-                #         named_params[name] = param
-
-                for name, param in self.unet.transformer_blocks.named_parameters(recurse=True,
-                                                                                 prefix="transformer.transformer_blocks"):
-                    named_params[name] = param
-                for name, param in self.unet.single_transformer_blocks.named_parameters(recurse=True,
-                                                                                        prefix="transformer.single_transformer_blocks"):
+                for name, param in self.unet.named_parameters(recurse=True, prefix="transformer"):
                     named_params[name] = param
             else:
                 for name, param in self.unet.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
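For reference, a minimal sketch of what the simplified collection does, assuming self.unet is a diffusers FluxTransformer2DModel: walking the whole module with a "transformer" prefix picks up both transformer_blocks and single_transformer_blocks (plus the embedding layers) instead of enumerating the two block lists separately.

    from diffusers import FluxTransformer2DModel

    def collect_flux_params(transformer: FluxTransformer2DModel) -> dict:
        # Keys come back as e.g. "transformer.transformer_blocks.0.attn.to_q.weight".
        named_params = {}
        for name, param in transformer.named_parameters(recurse=True, prefix="transformer"):
            named_params[name] = param
        return named_params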