Add support for training Z-Image Turbo with a de-distill training adapter

This commit is contained in:
Jaret Burkett
2025-11-28 08:08:53 -07:00
parent 21bb8a2bf4
commit 4e62c38df5
11 changed files with 459 additions and 7 deletions

View File

@@ -376,6 +376,11 @@ class ToolkitModuleMixin:
if hasattr(self, 'scalar'):
scale = scale * self.scalar
weight_device = weight.device
if weight.device != down_weight.device:
weight = weight.to(down_weight.device)
if scale.device != down_weight.device:
scale = scale.to(down_weight.device)
# merge weight
if self.full_rank:
weight = weight + multiplier * down_weight * scale
@@ -397,7 +402,7 @@ class ToolkitModuleMixin:
weight = weight + multiplier * conved * scale
# set weight to org_module
org_sd[weight_key] = weight.to(orig_dtype)
org_sd[weight_key] = weight.to(weight_device, orig_dtype)
self.org_module[0].load_state_dict(org_sd)
def setup_lorm(self: Module, state_dict: Optional[Dict[str, Any]] = None):