Add support for training Z-Image Turbo with a de-distill training adapter

Jaret Burkett
2025-11-28 08:08:53 -07:00
parent 21bb8a2bf4
commit 4e62c38df5
11 changed files with 459 additions and 7 deletions


@@ -805,7 +805,11 @@ class BaseModel:
         # check if batch size of embeddings matches batch size of latents
         if isinstance(text_embeddings.text_embeds, list):
-            te_batch_size = text_embeddings.text_embeds[0].shape[0]
+            if len(text_embeddings.text_embeds) == latents.shape[0]:
+                # handle list of embeddings
+                te_batch_size = len(text_embeddings.text_embeds)
+            else:
+                te_batch_size = text_embeddings.text_embeds[0].shape[0]
         else:
             te_batch_size = text_embeddings.text_embeds.shape[0]
         if latents.shape[0] == te_batch_size:
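
The hunk above changes how the batch size is read when text_embeds is a list: if the list length matches the latent batch, each element is treated as one sample's embedding; otherwise the old behavior (first tensor's leading dimension) is kept. A minimal standalone sketch of the two layouts it distinguishes, with hypothetical shapes not taken from this repo:

import torch

latents = torch.randn(4, 16, 64, 64)  # latent batch of 4 (hypothetical shapes)

# Layout A: one embedding tensor per sample -> list length equals the batch
per_sample = [torch.randn(77, 2048) for _ in range(4)]
if len(per_sample) == latents.shape[0]:
    te_batch_size = len(per_sample)          # 4, from the list length

# Layout B: already-batched tensors, e.g. one per text encoder
per_encoder = [torch.randn(4, 77, 768), torch.randn(4, 77, 1280)]
if len(per_encoder) != latents.shape[0]:
    te_batch_size = per_encoder[0].shape[0]  # 4, from the leading dimension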


@@ -376,6 +376,11 @@ class ToolkitModuleMixin:
         if hasattr(self, 'scalar'):
             scale = scale * self.scalar
+        weight_device = weight.device
+        if weight.device != down_weight.device:
+            weight = weight.to(down_weight.device)
+        if scale.device != down_weight.device:
+            scale = scale.to(down_weight.device)
         # merge weight
         if self.full_rank:
             weight = weight + multiplier * down_weight * scale
@@ -397,7 +402,7 @@ class ToolkitModuleMixin:
             weight = weight + multiplier * conved * scale
         # set weight to org_module
-        org_sd[weight_key] = weight.to(orig_dtype)
+        org_sd[weight_key] = weight.to(weight_device, orig_dtype)
         self.org_module[0].load_state_dict(org_sd)

     def setup_lorm(self: Module, state_dict: Optional[Dict[str, Any]] = None):
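
These two hunks move the merge math onto the down weight's device and then restore the original placement: weight_device is captured before any transfer, and the final weight.to(weight_device, orig_dtype) puts device and dtype back in a single call. A standalone sketch of that round-trip, with stand-in merge math and hypothetical shapes:

import torch

# Stand-in tensors: the base weight lives on one device (e.g. CPU-offloaded),
# the adapter's down weight on another. Shapes and merge math are hypothetical.
weight = torch.randn(8, 8, dtype=torch.float16)       # original module weight
down_weight = torch.randn(4, 8)                       # adapter down weight
if torch.cuda.is_available():
    down_weight = down_weight.to('cuda')

orig_dtype = weight.dtype
weight_device = weight.device                  # remember the module's device
if weight.device != down_weight.device:
    weight = weight.to(down_weight.device)     # merge on the adapter's device

merged = weight.float() + down_weight.t() @ down_weight  # stand-in merge

# Tensor.to accepts device and dtype together, restoring both in one call
merged = merged.to(weight_device, orig_dtype)
assert merged.device == weight_device and merged.dtype == orig_dtype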


@@ -72,7 +72,10 @@ class PromptEmbeds:
         if self.pooled_embeds is not None:
             prompt_embeds = PromptEmbeds([cloned_text_embeds, self.pooled_embeds.clone()])
         else:
-            prompt_embeds = PromptEmbeds(cloned_text_embeds)
+            if isinstance(cloned_text_embeds, list) or isinstance(cloned_text_embeds, tuple):
+                prompt_embeds = PromptEmbeds([cloned_text_embeds, None])
+            else:
+                prompt_embeds = PromptEmbeds(cloned_text_embeds)
         if self.attention_mask is not None:
             if isinstance(self.attention_mask, list) or isinstance(self.attention_mask, tuple):
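
The new branch above keeps clone() working when the text embeddings are a list (e.g. one tensor per text encoder): there is no single pooled tensor in that case, so the clone is re-wrapped as [cloned_text_embeds, None] to preserve the (text_embeds, pooled_embeds) pair the constructor expects. A minimal sketch of the same branching, with hypothetical shapes:

import torch

# List-style text embeddings, e.g. one tensor per text encoder
# (hypothetical shapes, not taken from this repo)
text_embeds = [torch.randn(1, 77, 768), torch.randn(1, 77, 1280)]

cloned_text_embeds = [t.clone() for t in text_embeds]
if isinstance(cloned_text_embeds, (list, tuple)):
    # no single pooled tensor exists for a list, so pair it with None
    packed = [cloned_text_embeds, None]
else:
    packed = cloned_text_embeds  # single tensor passes through unchanged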