mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-13 22:49:48 +00:00
Setup to retrain guidance embedding for flux. Use defualt timestep distribution for flux
This commit is contained in:
@@ -464,7 +464,13 @@ class StableDiffusion:
|
||||
subfolder = None
|
||||
transformer_path = os.path.join(transformer_path, 'transformer')
|
||||
|
||||
transformer = FluxTransformer2DModel.from_pretrained(transformer_path, subfolder=subfolder, torch_dtype=dtype)
|
||||
transformer = FluxTransformer2DModel.from_pretrained(
|
||||
transformer_path,
|
||||
subfolder=subfolder,
|
||||
torch_dtype=dtype,
|
||||
low_cpu_mem_usage=False,
|
||||
device_map=None
|
||||
)
|
||||
transformer.to(self.device_torch, dtype=dtype)
|
||||
flush()
|
||||
|
||||
@@ -1609,7 +1615,6 @@ class StableDiffusion:
|
||||
vae_scale_factor=VAE_SCALE_FACTOR * 2, # should be 16 not sure why
|
||||
)
|
||||
|
||||
# todo we do this on sd3 training. I think we do it here too? No paper
|
||||
noise_pred = precondition_model_outputs_sd3(noise_pred, latent_model_input, timestep)
|
||||
elif self.is_v3:
|
||||
noise_pred = self.unet(
|
||||
@@ -2053,6 +2058,12 @@ class StableDiffusion:
|
||||
# for name, param in block.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
|
||||
# named_params[name] = param
|
||||
|
||||
# train the guidance embedding
|
||||
if self.unet.config.guidance_embeds:
|
||||
transformer: FluxTransformer2DModel = self.unet
|
||||
for name, param in transformer.time_text_embed.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
|
||||
named_params[name] = param
|
||||
|
||||
for name, param in self.unet.transformer_blocks.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
|
||||
named_params[name] = param
|
||||
for name, param in self.unet.single_transformer_blocks.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
|
||||
|
||||
Reference in New Issue
Block a user