Fixed issue with rescaled LoRAs only saving at fp32

Jaret Burkett
2023-08-01 14:08:22 -06:00
parent 8b8d53888d
commit f53fd08690


@@ -84,7 +84,8 @@ class ModRescaleLoraProcess(BaseProcess):
             if self.scale_target == 'up_down' and key.endswith('.lora_up.weight') or key.endswith('.lora_down.weight'):
                 # would it be better to adjust the up weights for fp16 precision? Doing both should reduce chance of NaN
                 v = v * up_down_scale
-                new_state_dict[key] = v.to(get_torch_dtype(self.save_dtype))
+                v = v.detach().clone().to("cpu").to(self.save_dtype)
+                new_state_dict[key] = v
 
         save_meta = add_model_hash_to_meta(new_state_dict, save_meta)
         save_file(new_state_dict, self.output_path, save_meta)
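
For context, a minimal standalone sketch of the same idea: cast each tensor to the intended save dtype before handing the state dict to safetensors, so the file is not silently written at fp32. This is not the repository's code; the dtype_map and rescale_and_cast names are illustrative assumptions.

# Illustrative sketch only, assuming torch and safetensors are installed.
import torch
from safetensors.torch import save_file

# hypothetical mapping from config strings to torch dtypes
dtype_map = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32}

def rescale_and_cast(state_dict, up_down_scale, save_dtype="fp16"):
    # hypothetical helper: scale LoRA up/down weights, then cast every tensor
    target = dtype_map[save_dtype]
    new_state_dict = {}
    for key, v in state_dict.items():
        if key.endswith(".lora_up.weight") or key.endswith(".lora_down.weight"):
            v = v * up_down_scale
        # detach, move to CPU, and cast so the saved file carries the target dtype
        new_state_dict[key] = v.detach().clone().to("cpu").to(target)
    return new_state_dict

# usage, e.g.: save_file(rescale_and_cast(sd, 0.5, "fp16"), "rescaled_lora.safetensors")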