From fecc64e6461a123d3c409e60a0c29f04fb9d531b Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Wed, 16 Apr 2025 13:03:04 -0600
Subject: [PATCH] Update hidream defaults, pass additional information to flow
 guidance

---
 config/examples/train_lora_hidream_48.yaml | 14 ++++++++++----
 toolkit/guidance.py                        |  2 ++
 2 files changed, 12 insertions(+), 4 deletions(-)

diff --git a/config/examples/train_lora_hidream_48.yaml b/config/examples/train_lora_hidream_48.yaml
index fa5d4bc9..ea9c7dff 100644
--- a/config/examples/train_lora_hidream_48.yaml
+++ b/config/examples/train_lora_hidream_48.yaml
@@ -1,4 +1,4 @@
-# HiDream training is still highly experimental. The settings here will take ~36.3GB of vram to train.
+# HiDream training is still highly experimental. The settings here will take ~35.2GB of vram to train.
 # It is not possible to train on a single 24GB card yet, but I am working on it. If you have more VRAM
 # I highly recommend first disabling quantization on the model itself if you can. You can leave the TEs quantized.
 # HiDream has a mixture of experts that may take special training considerations that I do not
@@ -23,8 +23,13 @@ config:
 #      trigger_word: "p3r5on"
       network:
         type: "lora"
-        linear: 16
-        linear_alpha: 16
+        linear: 32
+        linear_alpha: 32
+        network_kwargs:
+          # It is probably best to ignore the mixture-of-experts layers since only 2 experts are active in each block.
+          # Training them does work, but proper training of them is not fully implemented, so it is not recommended.
+          ignore_if_contains:
+            - "ff_i"
       save:
         dtype: bfloat16 # precision to save
         save_every: 250 # save every this many steps
@@ -47,8 +52,9 @@ config:
         train_text_encoder: false # wont work with hidream
         gradient_checkpointing: true # need the on unless you have a ton of vram
         noise_scheduler: "flowmatch" # for training only
+        timestep_type: shift # sigmoid, shift, linear
         optimizer: "adamw8bit"
-        lr: 1e-4
+        lr: 2e-4
         # uncomment this to skip the pre training sample
 #        skip_first_sample: true
         # uncomment to completely disable sampling
diff --git a/toolkit/guidance.py b/toolkit/guidance.py
index 84242423..287d17e7 100644
--- a/toolkit/guidance.py
+++ b/toolkit/guidance.py
@@ -649,11 +649,13 @@ def targeted_flow_guidance(
         noise,
         timesteps
     ).detach()
+    unconditional_noisy_latents = sd.condition_noisy_latents(unconditional_noisy_latents, batch)
     conditional_noisy_latents = sd.add_noise(
         conditional_latents,
         noise,
         timesteps
     ).detach()
+    conditional_noisy_latents = sd.condition_noisy_latents(conditional_noisy_latents, batch)
 
     # disable the lora to get a baseline prediction
     sd.network.is_active = False
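
Note on the toolkit/guidance.py change (editorial, not part of the patch): the two added lines make targeted_flow_guidance() run sd.condition_noisy_latents() on both the unconditional and the conditional noisy latents right after noise is added, so whatever model-specific conditioning that hook attaches is present in both branches before the baseline and LoRA predictions are computed. A minimal sketch of the resulting flow, using only the calls visible in the hunk; the helper name and the pre-hunk variable names (unconditional_latents, conditional_latents, noise, timesteps) are assumptions made for illustration:

    def prepare_targeted_flow_latents(sd, batch, unconditional_latents, conditional_latents, noise, timesteps):
        # Hypothetical helper mirroring the patched section of targeted_flow_guidance();
        # the surrounding prediction and loss logic is elided.
        unconditional_noisy_latents = sd.add_noise(unconditional_latents, noise, timesteps).detach()
        # new in this patch: apply the model's extra latent conditioning to the unconditional branch
        unconditional_noisy_latents = sd.condition_noisy_latents(unconditional_noisy_latents, batch)

        conditional_noisy_latents = sd.add_noise(conditional_latents, noise, timesteps).detach()
        # new in this patch: apply the same conditioning to the conditional branch
        conditional_noisy_latents = sd.condition_noisy_latents(conditional_noisy_latents, batch)

        # the caller then sets sd.network.is_active = False for the baseline prediction
        return unconditional_noisy_latents, conditional_noisy_latents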
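
Note on the new timestep_type: shift setting (editorial, illustrative only): flow-matching trainers commonly remap sampled timesteps in [0, 1] toward the high-noise end with a shift function before noising, rather than using them uniformly or through a plain sigmoid. A hedged sketch of one common form of that remapping, assuming a standalone shift_timesteps() helper; the exact formula, shift value, and where the toolkit applies it may differ:

    import torch

    def shift_timesteps(t: torch.Tensor, shift: float = 3.0) -> torch.Tensor:
        # SD3/Flux-style shift: values in [0, 1] are pushed toward 1 (more noise) when shift > 1
        return shift * t / (1.0 + (shift - 1.0) * t)

    # usage: draw raw timesteps, then remap them before adding noise
    raw_t = torch.rand(4)
    shifted_t = shift_timesteps(raw_t, shift=3.0)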