From 2467c88c507cbc82fa6ff728d1577e7f3aedb7a2 Mon Sep 17 00:00:00 2001
From: DenOfEquity <166248528+DenOfEquity@users.noreply.github.com>
Date: Sun, 6 Oct 2024 14:33:47 +0100
Subject: [PATCH] fix for XPU (#1997)

use float32 for XPU, same as previous fix for MPS
---
 backend/nn/flux.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/backend/nn/flux.py b/backend/nn/flux.py
index 99546a55..fc2632fa 100644
--- a/backend/nn/flux.py
+++ b/backend/nn/flux.py
@@ -19,7 +19,7 @@ def attention(q, k, v, pe):
 
 
 def rope(pos, dim, theta):
-    if pos.device.type == "mps":
+    if pos.device.type == "mps" or pos.device.type == "xpu":
         scale = torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device) / dim
     else:
         scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
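
For context: PyTorch's MPS backend does not support float64, and common Intel
XPU (Arc) hardware lacks native float64 as well, which is why the rotary
position embedding's scale is computed in float32 on those devices. Below is a
minimal, self-contained sketch of the device-conditional dtype selection the
patch implements; the inverse-frequency step is assumed from the standard Flux
rope code, and rope_frequencies is an illustrative name, not the repo's.

    import torch

    def rope_frequencies(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
        # Mirror the patched condition: MPS and XPU fall back to float32,
        # since float64 is unsupported (MPS) or not natively available on
        # common hardware (XPU). All other backends keep float64.
        if pos.device.type in ("mps", "xpu"):
            dtype = torch.float32
        else:
            dtype = torch.float64
        scale = torch.arange(0, dim, 2, dtype=dtype, device=pos.device) / dim
        # Inverse frequencies for the sin/cos rotary tables (assumed from the
        # standard Flux rope implementation, which continues from `scale`).
        return 1.0 / (theta ** scale)

    # Usage (hypothetical device string): pos supplies only the device here.
    # freqs = rope_frequencies(torch.arange(16, device="xpu"), dim=128, theta=10000)

float64 is presumably retained on CUDA and CPU because theta ** scale spans
many orders of magnitude, so the extra precision is kept wherever the backend
supports it; float32 is the trade-off only where double precision is
unavailable.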