A lot of pixart sigma training tweaks

This commit is contained in:
Jaret Burkett
2024-07-28 11:23:18 -06:00
parent 80aa2dbb80
commit 0bc4d555c7
8 changed files with 118 additions and 29 deletions

View File

@@ -51,6 +51,9 @@ resolutions_1024: List[BucketResolution] = [
{"width": 512, "height": 1920},
{"width": 512, "height": 1984},
{"width": 512, "height": 2048},
# extra wides
{"width": 8192, "height": 128},
{"width": 128, "height": 8192},
]
# Even numbers so they can be patched easier

View File

@@ -128,6 +128,8 @@ class NetworkConfig:
if self.lorm_config.do_conv:
self.conv = 4
self.transformer_only = kwargs.get('transformer_only', False)
AdapterTypes = Literal['t2i', 'ip', 'ip+', 'clip', 'ilora', 'photo_maker', 'control_net']

View File

@@ -169,6 +169,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
target_conv_modules=LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3,
network_type: str = "lora",
full_train_in_out: bool = False,
transformer_only: bool = False,
**kwargs
) -> None:
"""
@@ -193,6 +194,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
if ignore_if_contains is None:
ignore_if_contains = []
self.ignore_if_contains = ignore_if_contains
self.transformer_only = transformer_only
self.only_if_contains: Union[List, None] = only_if_contains
@@ -271,6 +273,15 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
is_conv2d = child_module.__class__.__name__ in CONV_MODULES
is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
lora_name = [prefix, name, child_name]
# filter out blank
lora_name = [x for x in lora_name if x and x != ""]
lora_name = ".".join(lora_name)
# if it doesnt have a name, it wil have two dots
lora_name.replace("..", ".")
lora_name = lora_name.replace(".", "_")
skip = False
if any([word in child_name for word in self.ignore_if_contains]):
skip = True
@@ -279,9 +290,11 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
if count_parameters(child_module) < parameter_threshold:
skip = True
if self.transformer_only and self.is_pixart and is_unet:
if "transformer_blocks" not in lora_name:
skip = True
if (is_linear or is_conv2d) and not skip:
lora_name = prefix + "." + name + "." + child_name
lora_name = lora_name.replace(".", "_")
if self.only_if_contains is not None and not any([word in lora_name for word in self.only_if_contains]):
continue
@@ -356,8 +369,12 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
index = None
print(f"create LoRA for Text Encoder:")
text_encoder_loras, skipped = create_modules(False, index, text_encoder,
LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
replace_modules = LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
if self.is_pixart:
replace_modules = ["T5EncoderModel"]
text_encoder_loras, skipped = create_modules(False, index, text_encoder, replace_modules)
self.text_encoder_loras.extend(text_encoder_loras)
skipped_te += skipped
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")

View File

@@ -516,6 +516,9 @@ class ToolkitNetworkMixin:
load_sd = OrderedDict()
for key, value in weights_sd.items():
load_key = keymap[key] if key in keymap else key
# replace old double __ with single _
if self.is_pixart:
load_key = load_key.replace('__', '_')
load_sd[load_key] = value
# extract extra items from state dict

View File

@@ -169,15 +169,6 @@ class StableDiffusion:
if self.is_loaded:
return
dtype = get_torch_dtype(self.dtype)
# sch = KDPM2DiscreteScheduler
if self.noise_scheduler is None:
scheduler = get_sampler(
'ddpm', {
"prediction_type": self.prediction_type,
},
'sd' if not self.is_pixart else 'pixart'
)
self.noise_scheduler = scheduler
# move the betas alphas and alphas_cumprod to device. Sometimed they get stuck on cpu, not sure why
# self.noise_scheduler.betas = self.noise_scheduler.betas.to(self.device_torch)
@@ -190,9 +181,10 @@ class StableDiffusion:
from toolkit.civitai import get_model_path_from_url
model_path = get_model_path_from_url(self.model_config.name_or_path)
load_args = {
'scheduler': self.noise_scheduler,
}
load_args = {}
if self.noise_scheduler:
load_args['scheduler'] = self.noise_scheduler
if self.model_config.vae_path is not None:
load_args['vae'] = load_vae(self.model_config.vae_path, dtype)
if self.model_config.is_xl or self.model_config.is_ssd or self.model_config.is_vega:
@@ -290,6 +282,7 @@ class StableDiffusion:
device=self.device_torch,
torch_dtype=self.torch_dtype,
text_encoder_3=text_encoder3,
**load_args
)
flush()
@@ -387,6 +380,8 @@ class StableDiffusion:
tokenizer = pipe.tokenizer
pipe.vae = pipe.vae.to(self.vae_device_torch, dtype=self.vae_torch_dtype)
if self.noise_scheduler is None:
self.noise_scheduler = pipe.scheduler
elif self.model_config.is_auraflow: