From 971d1ed8b7205e35827c75fbce3b700fe4150003 Mon Sep 17 00:00:00 2001 From: Observer007 <45875558+Observer007@users.noreply.github.com> Date: Wed, 13 May 2026 09:06:44 +0800 Subject: [PATCH] fix for thor (#3224) --- .../cute/blackwell/kernel/attention/mla/mla_decode_fp16.py | 6 +++++- .../cute/blackwell/kernel/attention/mla/mla_decode_fp8.py | 6 +++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/examples/python/CuTeDSL/cute/blackwell/kernel/attention/mla/mla_decode_fp16.py b/examples/python/CuTeDSL/cute/blackwell/kernel/attention/mla/mla_decode_fp16.py index 8d75ffa67..a9e7c102b 100644 --- a/examples/python/CuTeDSL/cute/blackwell/kernel/attention/mla/mla_decode_fp16.py +++ b/examples/python/CuTeDSL/cute/blackwell/kernel/attention/mla/mla_decode_fp16.py @@ -2515,7 +2515,11 @@ class BlackwellMultiHeadLatentAttentionForwardFP16: # reduction for row_max row_max_new = tTR_rAcc.load().reduce(cute.ReductionOp.MAX, row_max_new, 0) - elif cutlass.const_expr(arch >= Arch.sm_103 and arch <= Arch.sm_103f): + elif cutlass.const_expr( + (arch >= Arch.sm_101 and arch <= Arch.sm_101f) + or (arch >= Arch.sm_103 and arch <= Arch.sm_103f) + or (arch >= Arch.sm_110 and arch <= Arch.sm_110f) + ): tmem_load_red_atom = cute.make_copy_atom( tcgen05.copy.LdRed32x32bOp( tcgen05.copy.Repetition(64), redOp=tcgen05.TmemLoadRedOp.MAX diff --git a/examples/python/CuTeDSL/cute/blackwell/kernel/attention/mla/mla_decode_fp8.py b/examples/python/CuTeDSL/cute/blackwell/kernel/attention/mla/mla_decode_fp8.py index 6693521cf..c43195f9f 100644 --- a/examples/python/CuTeDSL/cute/blackwell/kernel/attention/mla/mla_decode_fp8.py +++ b/examples/python/CuTeDSL/cute/blackwell/kernel/attention/mla/mla_decode_fp8.py @@ -2511,7 +2511,11 @@ class BlackwellMultiHeadLatentAttentionForwardFP8: ) # reduction for row_max row_max_new = tTR_rAcc.load().reduce(cute.ReductionOp.MAX, row_max_new, 0) - elif cutlass.const_expr(arch >= Arch.sm_103 and arch <= Arch.sm_103f): + elif cutlass.const_expr( + (arch >= Arch.sm_101 and arch <= Arch.sm_101f) + or (arch >= Arch.sm_103 and arch <= Arch.sm_103f) + or (arch >= Arch.sm_110 and arch <= Arch.sm_110f) + ): tmem_load_red_atom = cute.make_copy_atom( tcgen05.copy.LdRed32x32bOp( tcgen05.copy.Repetition(64), redOp=tcgen05.TmemLoadRedOp.MAX