mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-05-13 17:55:42 +00:00
fix for thor (#3224)
This commit is contained in:
@@ -2515,7 +2515,11 @@ class BlackwellMultiHeadLatentAttentionForwardFP16:
|
||||
# reduction for row_max
|
||||
row_max_new = tTR_rAcc.load().reduce(cute.ReductionOp.MAX, row_max_new, 0)
|
||||
|
||||
elif cutlass.const_expr(arch >= Arch.sm_103 and arch <= Arch.sm_103f):
|
||||
elif cutlass.const_expr(
|
||||
(arch >= Arch.sm_101 and arch <= Arch.sm_101f)
|
||||
or (arch >= Arch.sm_103 and arch <= Arch.sm_103f)
|
||||
or (arch >= Arch.sm_110 and arch <= Arch.sm_110f)
|
||||
):
|
||||
tmem_load_red_atom = cute.make_copy_atom(
|
||||
tcgen05.copy.LdRed32x32bOp(
|
||||
tcgen05.copy.Repetition(64), redOp=tcgen05.TmemLoadRedOp.MAX
|
||||
|
||||
@@ -2511,7 +2511,11 @@ class BlackwellMultiHeadLatentAttentionForwardFP8:
|
||||
)
|
||||
# reduction for row_max
|
||||
row_max_new = tTR_rAcc.load().reduce(cute.ReductionOp.MAX, row_max_new, 0)
|
||||
elif cutlass.const_expr(arch >= Arch.sm_103 and arch <= Arch.sm_103f):
|
||||
elif cutlass.const_expr(
|
||||
(arch >= Arch.sm_101 and arch <= Arch.sm_101f)
|
||||
or (arch >= Arch.sm_103 and arch <= Arch.sm_103f)
|
||||
or (arch >= Arch.sm_110 and arch <= Arch.sm_110f)
|
||||
):
|
||||
tmem_load_red_atom = cute.make_copy_atom(
|
||||
tcgen05.copy.LdRed32x32bOp(
|
||||
tcgen05.copy.Repetition(64), redOp=tcgen05.TmemLoadRedOp.MAX
|
||||
|
||||
Reference in New Issue
Block a user