diff --git a/examples/python/CuTeDSL/cute/blackwell/kernel/attention/mla/mla_decode_fp16.py b/examples/python/CuTeDSL/cute/blackwell/kernel/attention/mla/mla_decode_fp16.py
index 8d75ffa67..a9e7c102b 100644
--- a/examples/python/CuTeDSL/cute/blackwell/kernel/attention/mla/mla_decode_fp16.py
+++ b/examples/python/CuTeDSL/cute/blackwell/kernel/attention/mla/mla_decode_fp16.py
@@ -2515,7 +2515,11 @@ class BlackwellMultiHeadLatentAttentionForwardFP16:
                 # reduction for row_max
                 row_max_new = tTR_rAcc.load().reduce(cute.ReductionOp.MAX, row_max_new, 0)
-            elif cutlass.const_expr(arch >= Arch.sm_103 and arch <= Arch.sm_103f):
+            elif cutlass.const_expr(
+                (arch >= Arch.sm_101 and arch <= Arch.sm_101f)
+                or (arch >= Arch.sm_103 and arch <= Arch.sm_103f)
+                or (arch >= Arch.sm_110 and arch <= Arch.sm_110f)
+            ):
                 tmem_load_red_atom = cute.make_copy_atom(
                     tcgen05.copy.LdRed32x32bOp(
                         tcgen05.copy.Repetition(64), redOp=tcgen05.TmemLoadRedOp.MAX
                     )
diff --git a/examples/python/CuTeDSL/cute/blackwell/kernel/attention/mla/mla_decode_fp8.py b/examples/python/CuTeDSL/cute/blackwell/kernel/attention/mla/mla_decode_fp8.py
index 6693521cf..c43195f9f 100644
--- a/examples/python/CuTeDSL/cute/blackwell/kernel/attention/mla/mla_decode_fp8.py
+++ b/examples/python/CuTeDSL/cute/blackwell/kernel/attention/mla/mla_decode_fp8.py
@@ -2511,7 +2511,11 @@ class BlackwellMultiHeadLatentAttentionForwardFP8:
                 )
                 # reduction for row_max
                 row_max_new = tTR_rAcc.load().reduce(cute.ReductionOp.MAX, row_max_new, 0)
-            elif cutlass.const_expr(arch >= Arch.sm_103 and arch <= Arch.sm_103f):
+            elif cutlass.const_expr(
+                (arch >= Arch.sm_101 and arch <= Arch.sm_101f)
+                or (arch >= Arch.sm_103 and arch <= Arch.sm_103f)
+                or (arch >= Arch.sm_110 and arch <= Arch.sm_110f)
+            ):
                 tmem_load_red_atom = cute.make_copy_atom(
                     tcgen05.copy.LdRed32x32bOp(
                         tcgen05.copy.Repetition(64), redOp=tcgen05.TmemLoadRedOp.MAX