mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Update hidream defaults, pass additional information to flow guidance
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
# HiDream training is still highly experimental. The settings here will take ~36.3GB of vram to train.
|
||||
# HiDream training is still highly experimental. The settings here will take ~35.2GB of vram to train.
|
||||
# It is not possible to train on a single 24GB card yet, but I am working on it. If you have more VRAM
|
||||
# I highly recommend first disabling quantization on the model itself if you can. You can leave the TEs quantized.
|
||||
# HiDream has a mixture of experts that may take special training considerations that I do not
|
||||
@@ -23,8 +23,13 @@ config:
|
||||
# trigger_word: "p3r5on"
|
||||
network:
|
||||
type: "lora"
|
||||
linear: 16
|
||||
linear_alpha: 16
|
||||
linear: 32
|
||||
linear_alpha: 32
|
||||
network_kwargs:
|
||||
# It is probably best to ignore the mixture of experts since only 2 are active in each block. Activating it works, but I wouldn't.
|
||||
# proper training of it is not fully implemented
|
||||
ignore_if_contains:
|
||||
- "ff_i"
|
||||
save:
|
||||
dtype: bfloat16 # precision to save
|
||||
save_every: 250 # save every this many steps
|
||||
@@ -47,8 +52,9 @@ config:
|
||||
train_text_encoder: false # won't work with HiDream
|
||||
gradient_checkpointing: true # need this on unless you have a ton of VRAM
|
||||
noise_scheduler: "flowmatch" # for training only
|
||||
timestep_type: shift # sigmoid, shift, linear
|
||||
optimizer: "adamw8bit"
|
||||
lr: 1e-4
|
||||
lr: 2e-4
|
||||
# uncomment this to skip the pre-training sample
|
||||
# skip_first_sample: true
|
||||
# uncomment to completely disable sampling
|
||||
|
||||
@@ -649,11 +649,13 @@ def targeted_flow_guidance(
|
||||
noise,
|
||||
timesteps
|
||||
).detach()
|
||||
unconditional_noisy_latents = sd.condition_noisy_latents(unconditional_noisy_latents, batch)
|
||||
conditional_noisy_latents = sd.add_noise(
|
||||
conditional_latents,
|
||||
noise,
|
||||
timesteps
|
||||
).detach()
|
||||
conditional_noisy_latents = sd.condition_noisy_latents(conditional_noisy_latents, batch)
|
||||
|
||||
# disable the lora to get a baseline prediction
|
||||
sd.network.is_active = False
|
||||
|
||||
Reference in New Issue
Block a user