diff --git a/jobs/process/ModRescaleLoraProcess.py b/jobs/process/ModRescaleLoraProcess.py
index 882ef0e0..ff8304d4 100644
--- a/jobs/process/ModRescaleLoraProcess.py
+++ b/jobs/process/ModRescaleLoraProcess.py
@@ -84,7 +84,8 @@ class ModRescaleLoraProcess(BaseProcess):
             if self.scale_target == 'up_down' and key.endswith('.lora_up.weight') or key.endswith('.lora_down.weight'):
                 # would it be better to adjust the up weights for fp16 precision? Doing both should reduce chance of NaN
                 v = v * up_down_scale
-            new_state_dict[key] = v.to(get_torch_dtype(self.save_dtype))
+            v = v.detach().clone().to("cpu").to(self.save_dtype)
+            new_state_dict[key] = v

         save_meta = add_model_hash_to_meta(new_state_dict, save_meta)
         save_file(new_state_dict, self.output_path, save_meta)
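
For context, a minimal sketch (outside the patch, not the project's actual code) of what the changed save path does: each tensor is detached from the autograd graph, copied, moved to host memory, and cast to the save dtype before being handed to safetensors' save_file. The tensor names, shapes, and target dtype below are illustrative assumptions.

import torch
from safetensors.torch import save_file

# Assumed example LoRA weights; the real process iterates a loaded LoRA state dict.
state_dict = {
    "unet.lora_up.weight": torch.randn(4, 8, requires_grad=True),
    "unet.lora_down.weight": torch.randn(8, 4, requires_grad=True),
}

save_dtype = torch.float16  # assumed target dtype; the process resolves this from its config

new_state_dict = {}
for key, v in state_dict.items():
    # detach() drops any autograd history, clone() makes an independent copy that
    # does not share storage with other tensors, .to("cpu") moves it to host memory,
    # and the final .to(save_dtype) casts it to the on-disk precision.
    new_state_dict[key] = v.detach().clone().to("cpu").to(save_dtype)

save_file(new_state_dict, "rescaled_lora.safetensors")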