Update hidream defaults, pass additional information to flow guidance
@@ -1,4 +1,4 @@
-# HiDream training is still highly experimental. The settings here will take ~36.3GB of vram to train.
+# HiDream training is still highly experimental. The settings here will take ~35.2GB of vram to train.
 # It is not possible to train on a single 24GB card yet, but I am working on it. If you have more VRAM
 # I highly recommend first disabling quantization on the model itself if you can. You can leave the TEs quantized.
 # HiDream has a mixture of experts that may take special training considerations that I do not
@@ -23,8 +23,13 @@ config:
       # trigger_word: "p3r5on"
       network:
         type: "lora"
-        linear: 16
-        linear_alpha: 16
+        linear: 32
+        linear_alpha: 32
+        network_kwargs:
+          # it is probably best to ignore the mixture of experts since only 2 are active each block. It works activating it, but I wouldnt.
+          # proper training of it is not fully implemented
+          ignore_if_contains:
+            - "ff_i"
       save:
         dtype: bfloat16 # precision to save
         save_every: 250 # save every this many steps
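The new network_kwargs block keeps LoRA off HiDream's per-expert feed-forward layers (module names containing "ff_i"). As a minimal sketch of how such a name filter can pick LoRA targets (the helper below is illustrative, assuming standard PyTorch module naming; it is not ai-toolkit's actual injection code):

```python
import torch.nn as nn

def lora_target_modules(model: nn.Module, ignore_if_contains: list[str]) -> list[str]:
    """Collect names of Linear modules eligible for LoRA injection."""
    targets = []
    for name, module in model.named_modules():
        if not isinstance(module, nn.Linear):
            continue
        # skip anything matching the ignore list, e.g. HiDream expert
        # feed-forward layers whose names contain "ff_i"
        if any(fragment in name for fragment in ignore_if_contains):
            continue
        targets.append(name)
    return targets
```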
@@ -47,8 +52,9 @@ config:
         train_text_encoder: false # wont work with hidream
         gradient_checkpointing: true # need the on unless you have a ton of vram
         noise_scheduler: "flowmatch" # for training only
+        timestep_type: shift # sigmoid, shift, linear
         optimizer: "adamw8bit"
-        lr: 1e-4
+        lr: 2e-4
         # uncomment this to skip the pre training sample
         # skip_first_sample: true
         # uncomment to completely disable sampling
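The added timestep_type: shift selects how flow-matching timesteps are drawn during training. A rough sketch of the three listed options, using commonly published formulas (the shift value and exact functions below are assumptions, not ai-toolkit's sampler):

```python
import torch

def sample_timesteps(batch_size: int, timestep_type: str = "shift", shift: float = 3.0) -> torch.Tensor:
    t = torch.rand(batch_size)  # uniform draw in [0, 1)
    if timestep_type == "linear":
        return t
    if timestep_type == "sigmoid":
        # logit-normal sampling, concentrates mass at mid timesteps
        return torch.sigmoid(torch.randn(batch_size))
    if timestep_type == "shift":
        # SD3-style shift, biases sampling toward the high-noise end
        return shift * t / (1.0 + (shift - 1.0) * t)
    raise ValueError(f"unknown timestep_type: {timestep_type}")
```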
@@ -649,11 +649,13 @@ def targeted_flow_guidance(
         noise,
         timesteps
     ).detach()
+    unconditional_noisy_latents = sd.condition_noisy_latents(unconditional_noisy_latents, batch)
     conditional_noisy_latents = sd.add_noise(
         conditional_latents,
         noise,
         timesteps
     ).detach()
+    conditional_noisy_latents = sd.condition_noisy_latents(conditional_noisy_latents, batch)
 
     # disable the lora to get a baseline prediction
     sd.network.is_active = False
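The two added condition_noisy_latents calls forward model-specific conditioning from the batch into both latent branches before the guidance pass, which is the "pass additional information to flow guidance" half of the commit message. A condensed sketch of the surrounding pattern, with predict_noise as an assumed stand-in for the real prediction call:

```python
import torch

def guidance_baseline_sketch(sd, batch, conditional_latents, unconditional_latents, noise, timesteps):
    # noise both branches, then apply conditioning from the batch to each
    uncond = sd.add_noise(unconditional_latents, noise, timesteps).detach()
    uncond = sd.condition_noisy_latents(uncond, batch)
    cond = sd.add_noise(conditional_latents, noise, timesteps).detach()
    cond = sd.condition_noisy_latents(cond, batch)

    # baseline prediction with the LoRA switched off and no gradients
    sd.network.is_active = False
    with torch.no_grad():
        baseline = sd.predict_noise(cond, timesteps=timesteps)  # assumed call

    # re-enable the LoRA for the trainable prediction that follows
    sd.network.is_active = True
    return uncond, cond, baseline
```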