diff --git a/toolkit/models/wan21/wan21.py b/toolkit/models/wan21/wan21.py index 998b2312..4932b51d 100644 --- a/toolkit/models/wan21/wan21.py +++ b/toolkit/models/wan21/wan21.py @@ -654,10 +654,10 @@ class Wan21(BaseModel): return latents.to(device, dtype=dtype) def get_model_has_grad(self): - return self.model.proj_out.weight.requires_grad + return False def get_te_has_grad(self): - return self.text_encoder.encoder.block[0].layer[0].SelfAttention.q.weight.requires_grad + return False def save_model(self, output_path, meta, save_dtype): # only save the unet