From 604e76d34d2c1f032dcdb6cce1eef759b12d0e66 Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Mon, 17 Mar 2025 09:17:40 -0600
Subject: [PATCH] Fix issue with full finetuning wan

---
 toolkit/models/base_model.py | 13 ++++---------
 1 file changed, 4 insertions(+), 9 deletions(-)

diff --git a/toolkit/models/base_model.py b/toolkit/models/base_model.py
index 56c506da..48a04ca4 100644
--- a/toolkit/models/base_model.py
+++ b/toolkit/models/base_model.py
@@ -1181,15 +1181,10 @@ class BaseModel:
             vae=False, unet=unet, text_encoder=False, state_dict_keys=True)
         unet_lr = unet_lr if unet_lr is not None else default_lr
         params = []
-        if self.is_pixart or self.is_auraflow or self.is_flux or self.is_v3 or self.is_lumina2:
-            for param in named_params.values():
-                if param.requires_grad:
-                    params.append(param)
-        else:
-            for key, diffusers_key in ldm_diffusers_keymap.items():
-                if diffusers_key in named_params and diffusers_key not in DO_NOT_TRAIN_WEIGHTS:
-                    if named_params[diffusers_key].requires_grad:
-                        params.append(named_params[diffusers_key])
+        for param in named_params.values():
+            if param.requires_grad:
+                params.append(param)
+
         param_data = {"params": params, "lr": unet_lr}
         trainable_parameters.append(param_data)
         print_acc(f"Found {len(params)} trainable parameter in unet")
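
Note on the change: the patch removes the architecture check and the ldm_diffusers_keymap fallback and instead collects every parameter that has requires_grad set. The old else branch only picked up weights whose keys appear in ldm_diffusers_keymap, which presumably does not cover Wan's transformer keys, so full finetuning of Wan would end up with an empty (or incomplete) parameter list. Below is a minimal sketch of the unified collection logic; the helper name collect_trainable_params and the toy Linear module standing in for the unet are hypothetical, not part of the repository.

    import torch

    def collect_trainable_params(named_params, lr):
        # Keep every parameter that still requires gradients, regardless of
        # which architecture produced the state dict keys (hypothetical helper).
        params = [p for p in named_params.values() if p.requires_grad]
        return {"params": params, "lr": lr}

    # Toy module standing in for the transformer/unet.
    model = torch.nn.Linear(4, 4)
    group = collect_trainable_params(dict(model.named_parameters()), lr=1e-4)
    optimizer = torch.optim.AdamW([group])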