fix for XPU (#1997)

Use float32 for XPU, the same as the earlier fix for MPS, since these backends do not support float64.
Author: DenOfEquity
Date:   2024-10-06 14:33:47 +01:00
Committed by: GitHub
Parent: 4f7f815b9f
Commit: 2467c88c50


@@ -19,7 +19,7 @@ def attention(q, k, v, pe):
 def rope(pos, dim, theta):
-    if pos.device.type == "mps":
+    if pos.device.type == "mps" or pos.device.type == "xpu":
         scale = torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device) / dim
     else:
         scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
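
For context, a minimal sketch of the dtype-selection pattern this patch applies. The standalone helper name rope_scale is illustrative only; in the actual file this logic sits inline inside rope:

import torch

def rope_scale(pos: torch.Tensor, dim: int) -> torch.Tensor:
    # MPS and XPU backends do not support float64, so fall back to float32 there;
    # every other device keeps the higher-precision float64 computation.
    dtype = torch.float32 if pos.device.type in ("mps", "xpu") else torch.float64
    return torch.arange(0, dim, 2, dtype=dtype, device=pos.device) / dim

On an XPU or MPS build of PyTorch this returns a float32 tensor, while CUDA and CPU keep the float64 result.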