diff --git a/toolkit/models/base_model.py b/toolkit/models/base_model.py index 56c506da..48a04ca4 100644 --- a/toolkit/models/base_model.py +++ b/toolkit/models/base_model.py @@ -1181,15 +1181,10 @@ class BaseModel: vae=False, unet=unet, text_encoder=False, state_dict_keys=True) unet_lr = unet_lr if unet_lr is not None else default_lr params = [] - if self.is_pixart or self.is_auraflow or self.is_flux or self.is_v3 or self.is_lumina2: - for param in named_params.values(): - if param.requires_grad: - params.append(param) - else: - for key, diffusers_key in ldm_diffusers_keymap.items(): - if diffusers_key in named_params and diffusers_key not in DO_NOT_TRAIN_WEIGHTS: - if named_params[diffusers_key].requires_grad: - params.append(named_params[diffusers_key]) + for param in named_params.values(): + if param.requires_grad: + params.append(param) + param_data = {"params": params, "lr": unet_lr} trainable_parameters.append(param_data) print_acc(f"Found {len(params)} trainable parameter in unet")