mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 03:01:28 +00:00
Add support for training Z-Image Turbo with a de-distill training adapter
This commit is contained in:
@@ -805,7 +805,11 @@ class BaseModel:
|
||||
|
||||
# check if batch size of embeddings matches batch size of latents
|
||||
if isinstance(text_embeddings.text_embeds, list):
|
||||
- te_batch_size = text_embeddings.text_embeds[0].shape[0]
|
||||
+ if len(text_embeddings.text_embeds) == latents.shape[0]:
|
||||
+ # handle list of embeddings
|
||||
+ te_batch_size = len(text_embeddings.text_embeds)
|
||||
+ else:
|
||||
+ te_batch_size = text_embeddings.text_embeds[0].shape[0]
|
||||
else:
|
||||
te_batch_size = text_embeddings.text_embeds.shape[0]
|
||||
if latents.shape[0] == te_batch_size:
|
||||
|
||||
@@ -376,6 +376,11 @@ class ToolkitModuleMixin:
|
||||
if hasattr(self, 'scalar'):
|
||||
scale = scale * self.scalar
|
||||
|
||||
+ weight_device = weight.device
|
||||
+ if weight.device != down_weight.device:
|
||||
+ weight = weight.to(down_weight.device)
|
||||
+ if scale.device != down_weight.device:
|
||||
+ scale = scale.to(down_weight.device)
|
||||
# merge weight
|
||||
if self.full_rank:
|
||||
weight = weight + multiplier * down_weight * scale
|
||||
@@ -397,7 +402,7 @@ class ToolkitModuleMixin:
|
||||
weight = weight + multiplier * conved * scale
|
||||
|
||||
# set weight to org_module
|
||||
- org_sd[weight_key] = weight.to(orig_dtype)
|
||||
+ org_sd[weight_key] = weight.to(weight_device, orig_dtype)
|
||||
self.org_module[0].load_state_dict(org_sd)
|
||||
|
||||
def setup_lorm(self: Module, state_dict: Optional[Dict[str, Any]] = None):
|
||||
|
||||
@@ -72,7 +72,10 @@ class PromptEmbeds:
|
||||
if self.pooled_embeds is not None:
|
||||
prompt_embeds = PromptEmbeds([cloned_text_embeds, self.pooled_embeds.clone()])
|
||||
else:
|
||||
- prompt_embeds = PromptEmbeds(cloned_text_embeds)
|
||||
+ if isinstance(cloned_text_embeds, list) or isinstance(cloned_text_embeds, tuple):
|
||||
+ prompt_embeds = PromptEmbeds([cloned_text_embeds, None])
|
||||
+ else:
|
||||
+ prompt_embeds = PromptEmbeds(cloned_text_embeds)
|
||||
|
||||
if self.attention_mask is not None:
|
||||
if isinstance(self.attention_mask, list) or isinstance(self.attention_mask, tuple):
|
||||
|
||||
Reference in New Issue
Block a user