mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-28 18:21:16 +00:00
Added training for pixart-a
This commit is contained in:
@@ -151,6 +151,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
|
||||
train_unet: Optional[bool] = True,
|
||||
is_sdxl=False,
|
||||
is_v2=False,
|
||||
is_pixart: bool = False,
|
||||
use_bias: bool = False,
|
||||
is_lorm: bool = False,
|
||||
ignore_if_contains = None,
|
||||
@@ -197,6 +198,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
|
||||
self.multiplier = multiplier
|
||||
self.is_sdxl = is_sdxl
|
||||
self.is_v2 = is_v2
|
||||
self.is_pixart = is_pixart
|
||||
|
||||
if modules_dim is not None:
|
||||
print(f"create LoRA network from weights")
|
||||
@@ -224,8 +226,12 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
|
||||
root_module: torch.nn.Module,
|
||||
target_replace_modules: List[torch.nn.Module],
|
||||
) -> List[LoRAModule]:
|
||||
unet_prefix = self.LORA_PREFIX_UNET
|
||||
if is_pixart:
|
||||
unet_prefix = f"lora_transformer"
|
||||
|
||||
prefix = (
|
||||
self.LORA_PREFIX_UNET
|
||||
unet_prefix
|
||||
if is_unet
|
||||
else (
|
||||
self.LORA_PREFIX_TEXT_ENCODER
|
||||
|
||||
Reference in New Issue
Block a user