mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Only train a few blocks on flux (for now)
This commit is contained in:
@@ -448,13 +448,19 @@ class StableDiffusion:
|
||||
|
||||
elif self.model_config.is_flux:
|
||||
print("Loading Flux model")
|
||||
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(model_path, subfolder="scheduler")
|
||||
base_model_path = "/home/jaret/Dev/models/hf/FLUX.1-schnell/"
|
||||
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(base_model_path, subfolder="scheduler")
|
||||
print("Loading vae")
|
||||
vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae", torch_dtype=dtype)
|
||||
vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae", torch_dtype=dtype)
|
||||
flush()
|
||||
print("Loading transformer")
|
||||
subfolder = 'transformer'
|
||||
transformer_path = model_path
|
||||
if os.path.exists(transformer_path):
|
||||
subfolder = None
|
||||
transformer_path = os.path.join(transformer_path, 'transformer')
|
||||
|
||||
transformer = FluxTransformer2DModel.from_pretrained(model_path, subfolder="transformer", torch_dtype=dtype)
|
||||
transformer = FluxTransformer2DModel.from_pretrained(transformer_path, subfolder=subfolder, torch_dtype=dtype)
|
||||
transformer.to(self.device_torch, dtype=dtype)
|
||||
flush()
|
||||
|
||||
@@ -465,20 +471,20 @@ class StableDiffusion:
|
||||
flush()
|
||||
|
||||
print("Loading t5")
|
||||
text_encoder_2 = T5EncoderModel.from_pretrained(model_path, subfolder="text_encoder_2", torch_dtype=dtype)
|
||||
tokenizer_2 = T5TokenizerFast.from_pretrained(model_path, subfolder="tokenizer_2", torch_dtype=dtype)
|
||||
tokenizer_2 = T5TokenizerFast.from_pretrained(base_model_path, subfolder="tokenizer_2", torch_dtype=dtype)
|
||||
text_encoder_2 = T5EncoderModel.from_pretrained(base_model_path, subfolder="text_encoder_2", torch_dtype=dtype)
|
||||
|
||||
text_encoder_2.to(self.device_torch, dtype=dtype)
|
||||
flush()
|
||||
|
||||
if self.model_config.quantize:
|
||||
print("Quantizing T5")
|
||||
quantize(text_encoder_2, weights=qfloat8)
|
||||
freeze(text_encoder_2)
|
||||
print("Quantizing T5")
|
||||
quantize(text_encoder_2, weights=qfloat8)
|
||||
freeze(text_encoder_2)
|
||||
flush()
|
||||
|
||||
print("Loading clip")
|
||||
text_encoder = CLIPTextModel.from_pretrained(model_path, subfolder="text_encoder", torch_dtype=dtype)
|
||||
tokenizer = CLIPTokenizer.from_pretrained(model_path, subfolder="tokenizer", torch_dtype=dtype)
|
||||
text_encoder = CLIPTextModel.from_pretrained(base_model_path, subfolder="text_encoder", torch_dtype=dtype)
|
||||
tokenizer = CLIPTokenizer.from_pretrained(base_model_path, subfolder="tokenizer", torch_dtype=dtype)
|
||||
text_encoder.to(self.device_torch, dtype=dtype)
|
||||
|
||||
print("making pipe")
|
||||
@@ -772,7 +778,7 @@ class StableDiffusion:
|
||||
tokenizer_2=self.tokenizer[1],
|
||||
scheduler=noise_scheduler,
|
||||
**extra_args
|
||||
).to(self.device_torch)
|
||||
)
|
||||
pipeline.watermark = None
|
||||
elif self.is_v3:
|
||||
pipeline = Pipe(
|
||||
@@ -1910,7 +1916,7 @@ class StableDiffusion:
|
||||
else:
|
||||
latents = self.vae.encode(images).latent_dist.sample()
|
||||
# latents = self.vae.encode(images, return_dict=False)[0]
|
||||
latents = latents * self.vae.config['scaling_factor']
|
||||
latents = latents * (self.vae.config['scaling_factor'] - self.vae.config['shift_factor'])
|
||||
latents = latents.to(device, dtype=dtype)
|
||||
|
||||
return latents
|
||||
@@ -1930,7 +1936,7 @@ class StableDiffusion:
|
||||
if self.vae.device == 'cpu':
|
||||
self.vae.to(self.device)
|
||||
latents = latents.to(device, dtype=dtype)
|
||||
latents = latents / self.vae.config['scaling_factor']
|
||||
latents = (latents / self.vae.config['scaling_factor']) + self.vae.config['shift_factor']
|
||||
images = self.vae.decode(latents).sample
|
||||
images = images.to(device, dtype=dtype)
|
||||
|
||||
@@ -2031,8 +2037,30 @@ class StableDiffusion:
|
||||
for name, param in self.text_encoder.named_parameters(recurse=True, prefix=f"{SD_PREFIX_TEXT_ENCODER}"):
|
||||
named_params[name] = param
|
||||
if unet:
|
||||
for name, param in self.unet.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
|
||||
named_params[name] = param
|
||||
if self.is_flux:
|
||||
# Just train the middle 2 blocks of each transformer block
|
||||
block_list = []
|
||||
num_transformer_blocks = 2
|
||||
start_block = len(self.unet.transformer_blocks) // 2 - (num_transformer_blocks // 2)
|
||||
for i in range(num_transformer_blocks):
|
||||
block_list.append(self.unet.transformer_blocks[start_block + i])
|
||||
|
||||
num_single_transformer_blocks = 4
|
||||
start_block = len(self.unet.single_transformer_blocks) // 2 - (num_single_transformer_blocks // 2)
|
||||
for i in range(num_single_transformer_blocks):
|
||||
block_list.append(self.unet.single_transformer_blocks[start_block + i])
|
||||
|
||||
for block in block_list:
|
||||
for name, param in block.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
|
||||
named_params[name] = param
|
||||
|
||||
# for name, param in self.unet.transformer_blocks.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
|
||||
# named_params[name] = param
|
||||
# for name, param in self.unet.single_transformer_blocks.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
|
||||
# named_params[name] = param
|
||||
else:
|
||||
for name, param in self.unet.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
|
||||
named_params[name] = param
|
||||
|
||||
if refiner:
|
||||
for name, param in self.refiner_unet.named_parameters(recurse=True, prefix=f"{SD_PREFIX_REFINER_UNET}"):
|
||||
@@ -2127,10 +2155,19 @@ class StableDiffusion:
|
||||
# safe_serialization=True,
|
||||
# )
|
||||
# else:
|
||||
self.pipeline.save_pretrained(
|
||||
save_directory=output_file,
|
||||
safe_serialization=True,
|
||||
)
|
||||
if self.is_flux:
|
||||
# only save the unet
|
||||
transformer: FluxTransformer2DModel = self.unet
|
||||
transformer.save_pretrained(
|
||||
save_directory=os.path.join(output_file, 'transformer'),
|
||||
safe_serialization=True,
|
||||
)
|
||||
else:
|
||||
|
||||
self.pipeline.save_pretrained(
|
||||
save_directory=output_file,
|
||||
safe_serialization=True,
|
||||
)
|
||||
# save out meta config
|
||||
meta_path = os.path.join(output_file, 'aitk_meta.yaml')
|
||||
with open(meta_path, 'w') as f:
|
||||
|
||||
Reference in New Issue
Block a user