mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 19:21:39 +00:00
Add support for training Z-Image Turbo with a de-distill training adapter
This commit is contained in:
@@ -376,6 +376,11 @@ class ToolkitModuleMixin:
|
||||
if hasattr(self, 'scalar'):
|
||||
scale = scale * self.scalar
|
||||
|
||||
weight_device = weight.device
|
||||
if weight.device != down_weight.device:
|
||||
weight = weight.to(down_weight.device)
|
||||
if scale.device != down_weight.device:
|
||||
scale = scale.to(down_weight.device)
|
||||
# merge weight
|
||||
if self.full_rank:
|
||||
weight = weight + multiplier * down_weight * scale
|
||||
@@ -397,7 +402,7 @@ class ToolkitModuleMixin:
|
||||
weight = weight + multiplier * conved * scale
|
||||
|
||||
# set weight to org_module
|
||||
org_sd[weight_key] = weight.to(orig_dtype)
|
||||
org_sd[weight_key] = weight.to(weight_device, orig_dtype)
|
||||
self.org_module[0].load_state_dict(org_sd)
|
||||
|
||||
def setup_lorm(self: Module, state_dict: Optional[Dict[str, Any]] = None):
|
||||
|
||||
Reference in New Issue
Block a user