Only train a few blocks on flux (for now)

Commit 369aa143bc (parent 87ba867fdc)
Author: Jaret Burkett
Date: 2024-08-03 07:02:27 -06:00

```diff
@@ -448,13 +448,19 @@ class StableDiffusion:
         elif self.model_config.is_flux:
             print("Loading Flux model")
-            scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(model_path, subfolder="scheduler")
+            base_model_path = "/home/jaret/Dev/models/hf/FLUX.1-schnell/"
+            scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(base_model_path, subfolder="scheduler")
             print("Loading vae")
-            vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae", torch_dtype=dtype)
+            vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae", torch_dtype=dtype)
             flush()
             print("Loading transformer")
-            transformer = FluxTransformer2DModel.from_pretrained(model_path, subfolder="transformer", torch_dtype=dtype)
+            subfolder = 'transformer'
+            transformer_path = model_path
+            if os.path.exists(transformer_path):
+                subfolder = None
+                transformer_path = os.path.join(transformer_path, 'transformer')
+            transformer = FluxTransformer2DModel.from_pretrained(transformer_path, subfolder=subfolder, torch_dtype=dtype)
             transformer.to(self.device_torch, dtype=dtype)
             flush()
```
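After this hunk, the scheduler, VAE, and (below) both text encoders always load from the hard-coded `base_model_path` (a local FLUX.1-schnell checkout), while only the transformer follows `model_path`, which may now be either a local directory containing a `transformer/` subdirectory or a hub repo id with a `transformer` subfolder. A minimal standalone sketch of that resolution logic (the helper name is hypothetical):

```python
import os

import torch
from diffusers import FluxTransformer2DModel

def load_flux_transformer(model_path: str, dtype=torch.bfloat16) -> FluxTransformer2DModel:
    # Local directory: the commit expects weights under model_path/transformer
    # and loads that path directly (subfolder=None).
    if os.path.exists(model_path):
        return FluxTransformer2DModel.from_pretrained(
            os.path.join(model_path, 'transformer'), torch_dtype=dtype
        )
    # Otherwise treat model_path as a hub repo id with a "transformer" subfolder.
    return FluxTransformer2DModel.from_pretrained(
        model_path, subfolder="transformer", torch_dtype=dtype
    )
```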
```diff
@@ -465,20 +471,20 @@ class StableDiffusion:
             flush()
             print("Loading t5")
-            text_encoder_2 = T5EncoderModel.from_pretrained(model_path, subfolder="text_encoder_2", torch_dtype=dtype)
-            tokenizer_2 = T5TokenizerFast.from_pretrained(model_path, subfolder="tokenizer_2", torch_dtype=dtype)
+            tokenizer_2 = T5TokenizerFast.from_pretrained(base_model_path, subfolder="tokenizer_2", torch_dtype=dtype)
+            text_encoder_2 = T5EncoderModel.from_pretrained(base_model_path, subfolder="text_encoder_2", torch_dtype=dtype)
             text_encoder_2.to(self.device_torch, dtype=dtype)
             flush()
             if self.model_config.quantize:
-                print("Quantizing T5")
-                quantize(text_encoder_2, weights=qfloat8)
-                freeze(text_encoder_2)
+                print("Quantizing T5")
+                quantize(text_encoder_2, weights=qfloat8)
+                freeze(text_encoder_2)
                 flush()
             print("Loading clip")
-            text_encoder = CLIPTextModel.from_pretrained(model_path, subfolder="text_encoder", torch_dtype=dtype)
-            tokenizer = CLIPTokenizer.from_pretrained(model_path, subfolder="tokenizer", torch_dtype=dtype)
+            text_encoder = CLIPTextModel.from_pretrained(base_model_path, subfolder="text_encoder", torch_dtype=dtype)
+            tokenizer = CLIPTokenizer.from_pretrained(base_model_path, subfolder="tokenizer", torch_dtype=dtype)
             text_encoder.to(self.device_torch, dtype=dtype)
             print("making pipe")
```
```diff
@@ -772,7 +778,7 @@ class StableDiffusion:
                 tokenizer_2=self.tokenizer[1],
                 scheduler=noise_scheduler,
                 **extra_args
-            ).to(self.device_torch)
+            )
             pipeline.watermark = None
         elif self.is_v3:
             pipeline = Pipe(
```
```diff
@@ -1910,7 +1916,7 @@ class StableDiffusion:
         else:
             latents = self.vae.encode(images).latent_dist.sample()
         # latents = self.vae.encode(images, return_dict=False)[0]
-        latents = latents * self.vae.config['scaling_factor']
+        latents = latents * (self.vae.config['scaling_factor'] - self.vae.config['shift_factor'])
         latents = latents.to(device, dtype=dtype)
         return latents
```
```diff
@@ -1930,7 +1936,7 @@ class StableDiffusion:
         if self.vae.device == 'cpu':
             self.vae.to(self.device)
         latents = latents.to(device, dtype=dtype)
-        latents = latents / self.vae.config['scaling_factor']
+        latents = (latents / self.vae.config['scaling_factor']) + self.vae.config['shift_factor']
         images = self.vae.decode(latents).sample
         images = images.to(device, dtype=dtype)
```
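These two hunks fold the Flux VAE's `shift_factor` into the latent transforms. The decode side matches the convention used in diffusers' Flux pipelines, `z / scaling_factor + shift_factor`. Note that the encode side, `z * (scaling_factor - shift_factor)`, is not the exact inverse of that transform; the inverse would be `(z - shift_factor) * scaling_factor`. A quick numeric check, using the scaling and shift values I believe ship in the FLUX.1 VAE config (an assumption about the model in use):

```python
# Round-trip check of the encode/decode transforms in the two hunks above.
scale, shift = 0.3611, 0.1159  # assumed FLUX.1 VAE config values

def encode_commit(z):   # what this commit does on encode
    return z * (scale - shift)

def encode_inverse(z):  # exact inverse of the decode below
    return (z - shift) * scale

def decode(z):
    return z / scale + shift

print(decode(encode_commit(1.0)))   # ~0.795, not a round trip
print(decode(encode_inverse(1.0)))  # 1.0, exact round trip
```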
```diff
@@ -2031,8 +2037,30 @@ class StableDiffusion:
             for name, param in self.text_encoder.named_parameters(recurse=True, prefix=f"{SD_PREFIX_TEXT_ENCODER}"):
                 named_params[name] = param
         if unet:
-            for name, param in self.unet.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
-                named_params[name] = param
+            if self.is_flux:
+                # Just train the middle 2 blocks of each transformer block
+                block_list = []
+                num_transformer_blocks = 2
+                start_block = len(self.unet.transformer_blocks) // 2 - (num_transformer_blocks // 2)
+                for i in range(num_transformer_blocks):
+                    block_list.append(self.unet.transformer_blocks[start_block + i])
+                num_single_transformer_blocks = 4
+                start_block = len(self.unet.single_transformer_blocks) // 2 - (num_single_transformer_blocks // 2)
+                for i in range(num_single_transformer_blocks):
+                    block_list.append(self.unet.single_transformer_blocks[start_block + i])
+                for block in block_list:
+                    for name, param in block.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
+                        named_params[name] = param
+                # for name, param in self.unet.transformer_blocks.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
+                #     named_params[name] = param
+                # for name, param in self.unet.single_transformer_blocks.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
+                #     named_params[name] = param
+            else:
+                for name, param in self.unet.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
+                    named_params[name] = param
         if refiner:
             for name, param in self.refiner_unet.named_parameters(recurse=True, prefix=f"{SD_PREFIX_REFINER_UNET}"):
```
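This is the change the commit title refers to: for flux, only a centered window of blocks (2 double-stream `transformer_blocks` and 4 single-stream `single_transformer_blocks`) contributes trainable parameters; everything else is left out of the returned set. Assuming FLUX.1's published geometry of 19 double-stream and 38 single-stream blocks, the window arithmetic picks these indices:

```python
# Worked example of the centered-window selection above,
# assuming 19 double-stream and 38 single-stream blocks (FLUX.1).
def middle_window(n_blocks: int, n_train: int) -> range:
    start = n_blocks // 2 - n_train // 2
    return range(start, start + n_train)

print(list(middle_window(19, 2)))  # [8, 9]            -> transformer_blocks
print(list(middle_window(38, 4)))  # [17, 18, 19, 20]  -> single_transformer_blocks
```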
```diff
@@ -2127,10 +2155,19 @@ class StableDiffusion:
         #     safe_serialization=True,
         # )
         # else:
-        self.pipeline.save_pretrained(
-            save_directory=output_file,
-            safe_serialization=True,
-        )
+        if self.is_flux:
+            # only save the unet
+            transformer: FluxTransformer2DModel = self.unet
+            transformer.save_pretrained(
+                save_directory=os.path.join(output_file, 'transformer'),
+                safe_serialization=True,
+            )
+        else:
+            self.pipeline.save_pretrained(
+                save_directory=output_file,
+                safe_serialization=True,
+            )
         # save out meta config
         meta_path = os.path.join(output_file, 'aitk_meta.yaml')
         with open(meta_path, 'w') as f:
```
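For flux, a saved checkpoint directory now contains only the `transformer/` weights plus `aitk_meta.yaml` rather than a full pipeline (the "only save the unet" comment notwithstanding, `self.unet` holds the Flux transformer here). Reloading such a checkpoint therefore takes the local-directory branch added in the first hunk. A sketch of the round trip, with a hypothetical checkpoint path:

```python
import os

from diffusers import FluxTransformer2DModel

output_file = "/path/to/checkpoint"  # hypothetical save directory

# os.path.exists(output_file) is True for a local save, so loading goes
# through output_file/transformer directly, with subfolder=None.
transformer = FluxTransformer2DModel.from_pretrained(
    os.path.join(output_file, "transformer")
)
```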