diff --git a/backend/nn/flux.py b/backend/nn/flux.py index 99546a55..fc2632fa 100644 --- a/backend/nn/flux.py +++ b/backend/nn/flux.py @@ -19,7 +19,7 @@ def attention(q, k, v, pe): def rope(pos, dim, theta): - if pos.device.type == "mps": + if pos.device.type == "mps" or pos.device.type == "xpu": scale = torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device) / dim else: scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim