Work on lumina2

This commit is contained in:
Jaret Burkett
2025-02-08 14:52:39 -07:00
parent d138f07365
commit 9a7266275d
3 changed files with 34 additions and 11 deletions

View File

@@ -63,7 +63,7 @@ class LoRAModule(ToolkitModuleMixin, ExtractableModuleMixin, torch.nn.Module):
torch.nn.Module.__init__(self)
self.lora_name = lora_name
self.orig_module_ref = weakref.ref(org_module)
self.scalar = torch.tensor(1.0)
self.scalar = torch.tensor(1.0, device=org_module.weight.device)
# check if parent has bias. if not force use_bias to False
if org_module.bias is None:
use_bias = False
@@ -275,7 +275,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
unet_prefix = self.LORA_PREFIX_UNET
if self.peft_format:
unet_prefix = self.PEFT_PREFIX_UNET
if is_pixart or is_v3 or is_auraflow or is_flux:
if is_pixart or is_v3 or is_auraflow or is_flux or is_lumina2:
unet_prefix = f"lora_transformer"
if self.peft_format:
unet_prefix = "transformer"