mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Fix issue with full finetuning of Wan
This commit is contained in:
@@ -1181,15 +1181,10 @@ class BaseModel:
|
||||
vae=False, unet=unet, text_encoder=False, state_dict_keys=True)
|
||||
unet_lr = unet_lr if unet_lr is not None else default_lr
|
||||
params = []
|
||||
if self.is_pixart or self.is_auraflow or self.is_flux or self.is_v3 or self.is_lumina2:
|
||||
for param in named_params.values():
|
||||
if param.requires_grad:
|
||||
params.append(param)
|
||||
else:
|
||||
for key, diffusers_key in ldm_diffusers_keymap.items():
|
||||
if diffusers_key in named_params and diffusers_key not in DO_NOT_TRAIN_WEIGHTS:
|
||||
if named_params[diffusers_key].requires_grad:
|
||||
params.append(named_params[diffusers_key])
|
||||
for param in named_params.values():
|
||||
if param.requires_grad:
|
||||
params.append(param)
|
||||
|
||||
param_data = {"params": params, "lr": unet_lr}
|
||||
trainable_parameters.append(param_data)
|
||||
print_acc(f"Found {len(params)} trainable parameter in unet")
|
||||
|
||||
Reference in New Issue
Block a user