Added training for PixArt-α (pixart-alpha)

This commit is contained in:
Jaret Burkett
2024-02-13 16:00:04 -07:00
parent 4ec4025cbb
commit 93b52932c1
10 changed files with 288 additions and 24 deletions

View File

@@ -151,6 +151,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
train_unet: Optional[bool] = True,
is_sdxl=False,
is_v2=False,
is_pixart: bool = False,
use_bias: bool = False,
is_lorm: bool = False,
ignore_if_contains = None,
@@ -197,6 +198,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
self.multiplier = multiplier
self.is_sdxl = is_sdxl
self.is_v2 = is_v2
self.is_pixart = is_pixart
if modules_dim is not None:
print(f"create LoRA network from weights")
@@ -224,8 +226,12 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
root_module: torch.nn.Module,
target_replace_modules: List[torch.nn.Module],
) -> List[LoRAModule]:
unet_prefix = self.LORA_PREFIX_UNET
if is_pixart:
unet_prefix = f"lora_transformer"
prefix = (
self.LORA_PREFIX_UNET
unet_prefix
if is_unet
else (
self.LORA_PREFIX_TEXT_ENCODER