mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 11:11:37 +00:00
Only train a few blocks on flux (for now)
This commit is contained in:
@@ -448,13 +448,19 @@ class StableDiffusion:
|
|||||||
|
|
||||||
elif self.model_config.is_flux:
|
elif self.model_config.is_flux:
|
||||||
print("Loading Flux model")
|
print("Loading Flux model")
|
||||||
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(model_path, subfolder="scheduler")
|
base_model_path = "/home/jaret/Dev/models/hf/FLUX.1-schnell/"
|
||||||
|
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(base_model_path, subfolder="scheduler")
|
||||||
print("Loading vae")
|
print("Loading vae")
|
||||||
vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae", torch_dtype=dtype)
|
vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae", torch_dtype=dtype)
|
||||||
flush()
|
flush()
|
||||||
print("Loading transformer")
|
print("Loading transformer")
|
||||||
|
subfolder = 'transformer'
|
||||||
|
transformer_path = model_path
|
||||||
|
if os.path.exists(transformer_path):
|
||||||
|
subfolder = None
|
||||||
|
transformer_path = os.path.join(transformer_path, 'transformer')
|
||||||
|
|
||||||
transformer = FluxTransformer2DModel.from_pretrained(model_path, subfolder="transformer", torch_dtype=dtype)
|
transformer = FluxTransformer2DModel.from_pretrained(transformer_path, subfolder=subfolder, torch_dtype=dtype)
|
||||||
transformer.to(self.device_torch, dtype=dtype)
|
transformer.to(self.device_torch, dtype=dtype)
|
||||||
flush()
|
flush()
|
||||||
|
|
||||||
@@ -465,20 +471,20 @@ class StableDiffusion:
|
|||||||
flush()
|
flush()
|
||||||
|
|
||||||
print("Loading t5")
|
print("Loading t5")
|
||||||
text_encoder_2 = T5EncoderModel.from_pretrained(model_path, subfolder="text_encoder_2", torch_dtype=dtype)
|
tokenizer_2 = T5TokenizerFast.from_pretrained(base_model_path, subfolder="tokenizer_2", torch_dtype=dtype)
|
||||||
tokenizer_2 = T5TokenizerFast.from_pretrained(model_path, subfolder="tokenizer_2", torch_dtype=dtype)
|
text_encoder_2 = T5EncoderModel.from_pretrained(base_model_path, subfolder="text_encoder_2", torch_dtype=dtype)
|
||||||
|
|
||||||
text_encoder_2.to(self.device_torch, dtype=dtype)
|
text_encoder_2.to(self.device_torch, dtype=dtype)
|
||||||
flush()
|
flush()
|
||||||
|
|
||||||
if self.model_config.quantize:
|
print("Quantizing T5")
|
||||||
print("Quantizing T5")
|
quantize(text_encoder_2, weights=qfloat8)
|
||||||
quantize(text_encoder_2, weights=qfloat8)
|
freeze(text_encoder_2)
|
||||||
freeze(text_encoder_2)
|
|
||||||
flush()
|
flush()
|
||||||
|
|
||||||
print("Loading clip")
|
print("Loading clip")
|
||||||
text_encoder = CLIPTextModel.from_pretrained(model_path, subfolder="text_encoder", torch_dtype=dtype)
|
text_encoder = CLIPTextModel.from_pretrained(base_model_path, subfolder="text_encoder", torch_dtype=dtype)
|
||||||
tokenizer = CLIPTokenizer.from_pretrained(model_path, subfolder="tokenizer", torch_dtype=dtype)
|
tokenizer = CLIPTokenizer.from_pretrained(base_model_path, subfolder="tokenizer", torch_dtype=dtype)
|
||||||
text_encoder.to(self.device_torch, dtype=dtype)
|
text_encoder.to(self.device_torch, dtype=dtype)
|
||||||
|
|
||||||
print("making pipe")
|
print("making pipe")
|
||||||
@@ -772,7 +778,7 @@ class StableDiffusion:
|
|||||||
tokenizer_2=self.tokenizer[1],
|
tokenizer_2=self.tokenizer[1],
|
||||||
scheduler=noise_scheduler,
|
scheduler=noise_scheduler,
|
||||||
**extra_args
|
**extra_args
|
||||||
).to(self.device_torch)
|
)
|
||||||
pipeline.watermark = None
|
pipeline.watermark = None
|
||||||
elif self.is_v3:
|
elif self.is_v3:
|
||||||
pipeline = Pipe(
|
pipeline = Pipe(
|
||||||
@@ -1910,7 +1916,7 @@ class StableDiffusion:
|
|||||||
else:
|
else:
|
||||||
latents = self.vae.encode(images).latent_dist.sample()
|
latents = self.vae.encode(images).latent_dist.sample()
|
||||||
# latents = self.vae.encode(images, return_dict=False)[0]
|
# latents = self.vae.encode(images, return_dict=False)[0]
|
||||||
latents = latents * self.vae.config['scaling_factor']
|
latents = latents * (self.vae.config['scaling_factor'] - self.vae.config['shift_factor'])
|
||||||
latents = latents.to(device, dtype=dtype)
|
latents = latents.to(device, dtype=dtype)
|
||||||
|
|
||||||
return latents
|
return latents
|
||||||
@@ -1930,7 +1936,7 @@ class StableDiffusion:
|
|||||||
if self.vae.device == 'cpu':
|
if self.vae.device == 'cpu':
|
||||||
self.vae.to(self.device)
|
self.vae.to(self.device)
|
||||||
latents = latents.to(device, dtype=dtype)
|
latents = latents.to(device, dtype=dtype)
|
||||||
latents = latents / self.vae.config['scaling_factor']
|
latents = (latents / self.vae.config['scaling_factor']) + self.vae.config['shift_factor']
|
||||||
images = self.vae.decode(latents).sample
|
images = self.vae.decode(latents).sample
|
||||||
images = images.to(device, dtype=dtype)
|
images = images.to(device, dtype=dtype)
|
||||||
|
|
||||||
@@ -2031,8 +2037,30 @@ class StableDiffusion:
|
|||||||
for name, param in self.text_encoder.named_parameters(recurse=True, prefix=f"{SD_PREFIX_TEXT_ENCODER}"):
|
for name, param in self.text_encoder.named_parameters(recurse=True, prefix=f"{SD_PREFIX_TEXT_ENCODER}"):
|
||||||
named_params[name] = param
|
named_params[name] = param
|
||||||
if unet:
|
if unet:
|
||||||
for name, param in self.unet.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
|
if self.is_flux:
|
||||||
named_params[name] = param
|
# Just train the middle 2 blocks of each transformer block
|
||||||
|
block_list = []
|
||||||
|
num_transformer_blocks = 2
|
||||||
|
start_block = len(self.unet.transformer_blocks) // 2 - (num_transformer_blocks // 2)
|
||||||
|
for i in range(num_transformer_blocks):
|
||||||
|
block_list.append(self.unet.transformer_blocks[start_block + i])
|
||||||
|
|
||||||
|
num_single_transformer_blocks = 4
|
||||||
|
start_block = len(self.unet.single_transformer_blocks) // 2 - (num_single_transformer_blocks // 2)
|
||||||
|
for i in range(num_single_transformer_blocks):
|
||||||
|
block_list.append(self.unet.single_transformer_blocks[start_block + i])
|
||||||
|
|
||||||
|
for block in block_list:
|
||||||
|
for name, param in block.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
|
||||||
|
named_params[name] = param
|
||||||
|
|
||||||
|
# for name, param in self.unet.transformer_blocks.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
|
||||||
|
# named_params[name] = param
|
||||||
|
# for name, param in self.unet.single_transformer_blocks.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
|
||||||
|
# named_params[name] = param
|
||||||
|
else:
|
||||||
|
for name, param in self.unet.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
|
||||||
|
named_params[name] = param
|
||||||
|
|
||||||
if refiner:
|
if refiner:
|
||||||
for name, param in self.refiner_unet.named_parameters(recurse=True, prefix=f"{SD_PREFIX_REFINER_UNET}"):
|
for name, param in self.refiner_unet.named_parameters(recurse=True, prefix=f"{SD_PREFIX_REFINER_UNET}"):
|
||||||
@@ -2127,10 +2155,19 @@ class StableDiffusion:
|
|||||||
# safe_serialization=True,
|
# safe_serialization=True,
|
||||||
# )
|
# )
|
||||||
# else:
|
# else:
|
||||||
self.pipeline.save_pretrained(
|
if self.is_flux:
|
||||||
save_directory=output_file,
|
# only save the unet
|
||||||
safe_serialization=True,
|
transformer: FluxTransformer2DModel = self.unet
|
||||||
)
|
transformer.save_pretrained(
|
||||||
|
save_directory=os.path.join(output_file, 'transformer'),
|
||||||
|
safe_serialization=True,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
|
||||||
|
self.pipeline.save_pretrained(
|
||||||
|
save_directory=output_file,
|
||||||
|
safe_serialization=True,
|
||||||
|
)
|
||||||
# save out meta config
|
# save out meta config
|
||||||
meta_path = os.path.join(output_file, 'aitk_meta.yaml')
|
meta_path = os.path.join(output_file, 'aitk_meta.yaml')
|
||||||
with open(meta_path, 'w') as f:
|
with open(meta_path, 'w') as f:
|
||||||
|
|||||||
Reference in New Issue
Block a user