mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-01 03:31:35 +00:00
Added training for an experimental decorator embedding. Allow for turning off guidance embedding on flux (for unreleased model). Various bug fixes and modifications
This commit is contained in:
35
toolkit/models/flux.py
Normal file
35
toolkit/models/flux.py
Normal file
@@ -0,0 +1,35 @@
|
||||
|
||||
# forward that bypasses the guidance embedding so it can be avoided during training.
|
||||
from functools import partial
|
||||
|
||||
|
||||
def guidance_embed_bypass_forward(self, timestep, guidance, pooled_projection):
    """Replacement forward for the combined time/text embedder that skips the
    guidance embedding entirely.

    Mirrors the stock forward pass but never touches a guidance embedder; the
    ``guidance`` argument is accepted only for call-signature compatibility and
    is ignored. Returns the conditioning tensor: timestep embedding plus the
    pooled text embedding.
    """
    # Project the raw timestep, then embed it in the pooled projection's dtype
    # so the two summands agree.
    proj = self.time_proj(timestep)
    time_emb = self.timestep_embedder(proj.to(dtype=pooled_projection.dtype))  # (N, D)
    # Embed the pooled text projection and combine with the timestep embedding.
    text_emb = self.text_embedder(pooled_projection)
    return time_emb + text_emb
|
||||
|
||||
# bypass the forward function
|
||||
|
||||
|
||||
def bypass_flux_guidance(transformer):
    """Monkey-patch ``transformer.time_text_embed.forward`` so the guidance
    embedding is skipped during training.

    The original forward is stashed on ``_bfg_orig_forward`` so it can be put
    back later by ``restore_flux_guidance``. Idempotent: does nothing if the
    bypass is already installed, or if the model has no guidance embedder.
    """
    embed = transformer.time_text_embed
    # Already patched — don't wrap twice (that would lose the real forward).
    if hasattr(embed, '_bfg_orig_forward'):
        return
    # Models without a guidance embedder have nothing to bypass.
    if not hasattr(embed, 'guidance_embedder'):
        return
    embed._bfg_orig_forward = embed.forward
    embed.forward = partial(guidance_embed_bypass_forward, embed)
|
||||
|
||||
# restore the forward function
|
||||
|
||||
|
||||
def restore_flux_guidance(transformer):
    """Undo ``bypass_flux_guidance``: reinstate the stashed forward and drop
    the marker attribute.

    No-op when the bypass was never installed, so it is always safe to call.
    """
    embed = transformer.time_text_embed
    if not hasattr(embed, '_bfg_orig_forward'):
        return
    embed.forward = embed._bfg_orig_forward
    del embed._bfg_orig_forward
|
||||
Reference in New Issue
Block a user