diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index 1a419cc8..8e7c2c2a 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -448,13 +448,19 @@ class StableDiffusion:
         elif self.model_config.is_flux:
             print("Loading Flux model")
-            scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(model_path, subfolder="scheduler")
+            base_model_path = "/home/jaret/Dev/models/hf/FLUX.1-schnell/"
+            scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(base_model_path, subfolder="scheduler")
             print("Loading vae")
-            vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae", torch_dtype=dtype)
+            vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae", torch_dtype=dtype)
             flush()

             print("Loading transformer")
-            transformer = FluxTransformer2DModel.from_pretrained(model_path, subfolder="transformer", torch_dtype=dtype)
+            subfolder = 'transformer'
+            transformer_path = model_path
+            if os.path.exists(transformer_path):
+                subfolder = None
+                transformer_path = os.path.join(transformer_path, 'transformer')
+            transformer = FluxTransformer2DModel.from_pretrained(transformer_path, subfolder=subfolder, torch_dtype=dtype)

             transformer.to(self.device_torch, dtype=dtype)
             flush()

@@ -465,20 +471,20 @@ class StableDiffusion:
             flush()

             print("Loading t5")
-            text_encoder_2 = T5EncoderModel.from_pretrained(model_path, subfolder="text_encoder_2", torch_dtype=dtype)
-            tokenizer_2 = T5TokenizerFast.from_pretrained(model_path, subfolder="tokenizer_2", torch_dtype=dtype)
+            tokenizer_2 = T5TokenizerFast.from_pretrained(base_model_path, subfolder="tokenizer_2", torch_dtype=dtype)
+            text_encoder_2 = T5EncoderModel.from_pretrained(base_model_path, subfolder="text_encoder_2", torch_dtype=dtype)
+            text_encoder_2.to(self.device_torch, dtype=dtype)
             flush()

-            if self.model_config.quantize:
-                print("Quantizing T5")
-                quantize(text_encoder_2, weights=qfloat8)
-                freeze(text_encoder_2)
+            print("Quantizing T5")
+            quantize(text_encoder_2, weights=qfloat8)
+            freeze(text_encoder_2)
             flush()

             print("Loading clip")
-            text_encoder = CLIPTextModel.from_pretrained(model_path, subfolder="text_encoder", torch_dtype=dtype)
-            tokenizer = CLIPTokenizer.from_pretrained(model_path, subfolder="tokenizer", torch_dtype=dtype)
+            text_encoder = CLIPTextModel.from_pretrained(base_model_path, subfolder="text_encoder", torch_dtype=dtype)
+            tokenizer = CLIPTokenizer.from_pretrained(base_model_path, subfolder="tokenizer", torch_dtype=dtype)
             text_encoder.to(self.device_torch, dtype=dtype)

             print("making pipe")

@@ -772,7 +778,7 @@ class StableDiffusion:
                 tokenizer_2=self.tokenizer[1],
                 scheduler=noise_scheduler,
                 **extra_args
-            ).to(self.device_torch)
+            )
             pipeline.watermark = None
         elif self.is_v3:
             pipeline = Pipe(
@@ -1910,7 +1916,7 @@ class StableDiffusion:
             else:
                 latents = self.vae.encode(images).latent_dist.sample()
                 # latents = self.vae.encode(images, return_dict=False)[0]
-            latents = latents * self.vae.config['scaling_factor']
+            latents = (latents - self.vae.config['shift_factor']) * self.vae.config['scaling_factor']
             latents = latents.to(device, dtype=dtype)

             return latents
@@ -1930,7 +1936,7 @@ class StableDiffusion:
         if self.vae.device == 'cpu':
            self.vae.to(self.device)
        latents = latents.to(device, dtype=dtype)
-        latents = latents / self.vae.config['scaling_factor']
+        latents = (latents / self.vae.config['scaling_factor']) + self.vae.config['shift_factor']
        images = self.vae.decode(latents).sample
        images = images.to(device, dtype=dtype)
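Note on the two VAE hunks above: the encode and decode paths must be exact inverses, which is why the encode side subtracts the shift before scaling (`(latents - shift) * scale`) rather than multiplying by the difference of the two config values. A minimal sketch of the round trip, assuming a diffusers `AutoencoderKL` whose config exposes both `scaling_factor` and `shift_factor` (true of the Flux and SD3 VAEs); the function names are illustrative, not part of this codebase:

```python
import torch
from diffusers import AutoencoderKL

def encode_latents(vae: AutoencoderKL, images: torch.Tensor) -> torch.Tensor:
    # Sample from the posterior, then normalize: shift first, scale second.
    latents = vae.encode(images).latent_dist.sample()
    return (latents - vae.config.shift_factor) * vae.config.scaling_factor

def decode_latents(vae: AutoencoderKL, latents: torch.Tensor) -> torch.Tensor:
    # Exact inverse of encode_latents: undo the scale, then re-add the shift.
    latents = latents / vae.config.scaling_factor + vae.config.shift_factor
    return vae.decode(latents).sample
```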
@@ -2031,8 +2037,30 @@ class StableDiffusion:
             for name, param in self.text_encoder.named_parameters(recurse=True, prefix=f"{SD_PREFIX_TEXT_ENCODER}"):
                 named_params[name] = param
         if unet:
-            for name, param in self.unet.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
-                named_params[name] = param
+            if self.is_flux:
+                # Train only the middle blocks: 2 double-stream and 4 single-stream
+                block_list = []
+                num_transformer_blocks = 2
+                start_block = len(self.unet.transformer_blocks) // 2 - (num_transformer_blocks // 2)
+                for i in range(num_transformer_blocks):
+                    block_list.append(self.unet.transformer_blocks[start_block + i])

+                num_single_transformer_blocks = 4
+                start_block = len(self.unet.single_transformer_blocks) // 2 - (num_single_transformer_blocks // 2)
+                for i in range(num_single_transformer_blocks):
+                    block_list.append(self.unet.single_transformer_blocks[start_block + i])

+                for block in block_list:
+                    for name, param in block.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
+                        named_params[name] = param

+                # for name, param in self.unet.transformer_blocks.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
+                #     named_params[name] = param
+                # for name, param in self.unet.single_transformer_blocks.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
+                #     named_params[name] = param
+            else:
+                for name, param in self.unet.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
+                    named_params[name] = param

         if refiner:
             for name, param in self.refiner_unet.named_parameters(recurse=True, prefix=f"{SD_PREFIX_REFINER_UNET}"):
@@ -2127,10 +2155,19 @@ class StableDiffusion:
        #     safe_serialization=True,
        # )
        # else:
-        self.pipeline.save_pretrained(
-            save_directory=output_file,
-            safe_serialization=True,
-        )
+        if self.is_flux:
+            # only save the transformer; for Flux, self.unet holds the FluxTransformer2DModel
+            transformer: FluxTransformer2DModel = self.unet
+            transformer.save_pretrained(
+                save_directory=os.path.join(output_file, 'transformer'),
+                safe_serialization=True,
+            )
+        else:
+
+            self.pipeline.save_pretrained(
+                save_directory=output_file,
+                safe_serialization=True,
+            )
        # save out meta config
        meta_path = os.path.join(output_file, 'aitk_meta.yaml')
        with open(meta_path, 'w') as f:
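The parameter hunk above freezes everything except a centred slice of each transformer stack; the arithmetic reduces to a small helper. A minimal sketch, assuming FLUX.1's published sizes of 19 double-stream and 38 single-stream blocks (`middle_blocks` is a hypothetical name, not part of this codebase):

```python
from torch import nn

def middle_blocks(blocks: nn.ModuleList, n: int) -> list[nn.Module]:
    # Take the n blocks centred in the stack: start at len // 2 - n // 2.
    start = len(blocks) // 2 - n // 2
    return list(blocks[start:start + n])

# With 19 double-stream blocks, n=2 selects indices 8-9; with 38
# single-stream blocks, n=4 selects indices 17-20.
```

The save hunk writes only the transformer weights to `<output_file>/transformer`, which is exactly the layout the loader change in the first hunk probes with `os.path.exists(model_path)` before dropping the `subfolder` argument.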