Fix issue with full finetuning of Wan

Jaret Burkett
2025-03-17 09:17:40 -06:00
parent 6cde96ae5f
commit 604e76d34d


@@ -1181,15 +1181,10 @@ class BaseModel:
             vae=False, unet=unet, text_encoder=False, state_dict_keys=True)
         unet_lr = unet_lr if unet_lr is not None else default_lr
         params = []
-        if self.is_pixart or self.is_auraflow or self.is_flux or self.is_v3 or self.is_lumina2:
-            for param in named_params.values():
-                if param.requires_grad:
-                    params.append(param)
-        else:
-            for key, diffusers_key in ldm_diffusers_keymap.items():
-                if diffusers_key in named_params and diffusers_key not in DO_NOT_TRAIN_WEIGHTS:
-                    if named_params[diffusers_key].requires_grad:
-                        params.append(named_params[diffusers_key])
+        for param in named_params.values():
+            if param.requires_grad:
+                params.append(param)
         param_data = {"params": params, "lr": unet_lr}
         trainable_parameters.append(param_data)
         print_acc(f"Found {len(params)} trainable parameter in unet")