Update hidream defaults, pass additional information to flow guidance

This commit is contained in:
Jaret Burkett
2025-04-16 13:03:04 -06:00
parent d5a64006b5
commit fecc64e646
2 changed files with 12 additions and 4 deletions

View File

@@ -1,4 +1,4 @@
# HiDream training is still highly experimental. The settings here will take ~36.3GB of vram to train.
# HiDream training is still highly experimental. The settings here will take ~35.2GB of vram to train.
# It is not possible to train on a single 24GB card yet, but I am working on it. If you have more VRAM
# I highly recommend first disabling quantization on the model itself if you can. You can leave the TEs quantized.
# HiDream has a mixture of experts that may take special training considerations that I do not
@@ -23,8 +23,13 @@ config:
# trigger_word: "p3r5on"
network:
type: "lora"
linear: 16
linear_alpha: 16
linear: 32
linear_alpha: 32
network_kwargs:
# it is probably best to ignore the mixture of experts since only 2 are active each block. It works activating it, but I wouldnt.
# proper training of it is not fully implemented
ignore_if_contains:
- "ff_i"
save:
dtype: bfloat16 # precision to save
save_every: 250 # save every this many steps
@@ -47,8 +52,9 @@ config:
train_text_encoder: false # wont work with hidream
gradient_checkpointing: true # need the on unless you have a ton of vram
noise_scheduler: "flowmatch" # for training only
timestep_type: shift # sigmoid, shift, linear
optimizer: "adamw8bit"
lr: 1e-4
lr: 2e-4
# uncomment this to skip the pre training sample
# skip_first_sample: true
# uncomment to completely disable sampling

View File

@@ -649,11 +649,13 @@ def targeted_flow_guidance(
noise,
timesteps
).detach()
unconditional_noisy_latents = sd.condition_noisy_latents(unconditional_noisy_latents, batch)
conditional_noisy_latents = sd.add_noise(
conditional_latents,
noise,
timesteps
).detach()
conditional_noisy_latents = sd.condition_noisy_latents(conditional_noisy_latents, batch)
# disable the lora to get a baseline prediction
sd.network.is_active = False