[CuTeDSL] Fix: SM100 block-scale gemm overlapping accumulator (#2995)

* Fix: SM100 block-scale gemm overlapping accumulator

Signed-off-by: Hua Huang <huah@nvidia.com>

* Also include threads_per_warp fix

Signed-off-by: Hua Huang <huah@nvidia.com>

---------

Signed-off-by: Hua Huang <huah@nvidia.com>
This commit is contained in:
Hua Huang
2026-02-03 11:01:41 +08:00
committed by GitHub
parent a4eb0e05f6
commit 1cfbb53a23
2 changed files with 32 additions and 28 deletions

View File

@@ -117,6 +117,10 @@ Constraints:
"""
def ceil_div(a, b):
return (a + b - 1) // b
class Sm100BlockScaledPersistentDenseGemmKernel:
"""This class implements batched matrix multiplication (C = A x SFA x B x SFB) with support for various data types
and architectural features specific to Blackwell GPUs with persistent tile scheduling and warp specialization.
@@ -206,17 +210,18 @@ class Sm100BlockScaledPersistentDenseGemmKernel:
)
self.mma_warp_id = 4
self.tma_warp_id = 5
self.threads_per_cta = 32 * len(
self.threads_per_warp = 32
self.threads_per_cta = self.threads_per_warp * len(
(self.mma_warp_id, self.tma_warp_id, *self.epilog_warp_id)
)
# Set barrier id for epilogue sync and tmem ptr sync
self.epilog_sync_barrier = pipeline.NamedBarrier(
barrier_id=1,
num_threads=32 * len(self.epilog_warp_id),
num_threads=self.threads_per_warp * len(self.epilog_warp_id),
)
self.tmem_alloc_barrier = pipeline.NamedBarrier(
barrier_id=2,
num_threads=32 * len((self.mma_warp_id, *self.epilog_warp_id)),
num_threads=self.threads_per_warp * len((self.mma_warp_id, *self.epilog_warp_id)),
)
self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100")
SM100_TMEM_CAPACITY_COLUMNS = 512
@@ -376,7 +381,9 @@ class Sm100BlockScaledPersistentDenseGemmKernel:
self.num_accumulator_tmem_cols = self.cta_tile_shape_mnk[1] * self.num_acc_stage if not self.overlapping_accum else self.cta_tile_shape_mnk[1] * 2 - self.num_sf_tmem_cols
# Only when overlapping_accum is enabled, we need to release accumulator buffer early in epilogue
self.iter_acc_early_release_in_epilogue = self.num_sf_tmem_cols // self.epi_tile_n
# Use -1 since at that iteration the pipeline is updated after the tmem -> reg copy
num_subtiles_in_overlap_region = ceil_div(self.num_sf_tmem_cols, self.epi_tile_n)
self.iter_acc_early_release_in_epilogue = num_subtiles_in_overlap_region - 1
@cute.jit
def __call__(
@@ -748,7 +755,7 @@ class Sm100BlockScaledPersistentDenseGemmKernel:
# Initialize acc_pipeline (barrier) and states
acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
num_acc_consumer_threads = len(self.epilog_warp_id) * (
num_acc_consumer_threads = self.threads_per_warp * len(self.epilog_warp_id) * (
2 if use_2cta_instrs else 1
)
acc_pipeline_consumer_group = pipeline.CooperativeGroup(
@@ -1359,7 +1366,7 @@ class Sm100BlockScaledPersistentDenseGemmKernel:
# Threads/warps participating in tma store pipeline
c_producer_group = pipeline.CooperativeGroup(
pipeline.Agent.Thread,
32 * len(self.epilog_warp_id),
self.threads_per_warp * len(self.epilog_warp_id),
)
c_pipeline = pipeline.PipelineTmaStore.create(
num_stages=self.num_c_stage,
@@ -1432,8 +1439,7 @@ class Sm100BlockScaledPersistentDenseGemmKernel:
if subtile_idx == self.iter_acc_early_release_in_epilogue:
# Fence for TMEM load
cute.arch.fence_view_async_tmem_load()
with cute.arch.elect_one():
acc_pipeline.consumer_release(acc_consumer_state)
acc_pipeline.consumer_release(acc_consumer_state)
acc_consumer_state.advance()
#
@@ -1446,7 +1452,7 @@ class Sm100BlockScaledPersistentDenseGemmKernel:
#
# Store C to shared memory
#
c_buffer = (num_prev_subtiles + real_subtile_idx) % self.num_c_stage
c_buffer = (num_prev_subtiles + subtile_idx) % self.num_c_stage
cute.copy(
tiled_copy_r2s,
tRS_rC,
@@ -1474,8 +1480,7 @@ class Sm100BlockScaledPersistentDenseGemmKernel:
# Async arrive accumulator buffer empty
#
if cutlass.const_expr(not self.overlapping_accum):
with cute.arch.elect_one():
acc_pipeline.consumer_release(acc_consumer_state)
acc_pipeline.consumer_release(acc_consumer_state)
acc_consumer_state.advance()
#
@@ -2232,9 +2237,6 @@ def run(
# Create scale factor tensor SFA/SFB
def create_scale_factor_tensor(l, mn, k, sf_vec_size, dtype):
def ceil_div(a, b):
return (a + b - 1) // b
sf_k = ceil_div(k, sf_vec_size)
ref_shape = (l, mn, sf_k)

View File

@@ -151,6 +151,10 @@ Constraints:
"""
def ceil_div(a, b):
return (a + b - 1) // b
class Sm100BlockScaledPersistentDenseGemmKernel:
"""This class implements batched matrix multiplication (C = A x SFA x B x SFB) with support for various data types
and architectural features specific to Blackwell GPUs with persistent tile scheduling and warp specialization.
@@ -250,17 +254,18 @@ class Sm100BlockScaledPersistentDenseGemmKernel:
)
self.mma_warp_id = 4
self.tma_warp_id = 5
self.threads_per_cta = 32 * len(
self.threads_per_warp = 32
self.threads_per_cta = self.threads_per_warp * len(
(self.mma_warp_id, self.tma_warp_id, *self.epilog_warp_id)
)
# Set barrier id for epilogue sync and tmem ptr sync
self.epilog_sync_barrier = pipeline.NamedBarrier(
barrier_id=1,
num_threads=32 * len(self.epilog_warp_id),
num_threads=self.threads_per_warp * len(self.epilog_warp_id),
)
self.tmem_alloc_barrier = pipeline.NamedBarrier(
barrier_id=2,
num_threads=32 * len((self.mma_warp_id, *self.epilog_warp_id)),
num_threads=self.threads_per_warp * len((self.mma_warp_id, *self.epilog_warp_id)),
)
self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100")
SM100_TMEM_CAPACITY_COLUMNS = 512
@@ -420,7 +425,9 @@ class Sm100BlockScaledPersistentDenseGemmKernel:
self.num_accumulator_tmem_cols = self.cta_tile_shape_mnk[1] * self.num_acc_stage if not self.overlapping_accum else self.cta_tile_shape_mnk[1] * 2 - self.num_sf_tmem_cols
# Only when overlapping_accum is enabled, we need to release accumulator buffer early in epilogue
self.iter_acc_early_release_in_epilogue = self.num_sf_tmem_cols // self.epi_tile_n
# Use -1 since at that iteration the pipeline is updated after the tmem -> reg copy
num_subtiles_in_overlap_region = ceil_div(self.num_sf_tmem_cols, self.epi_tile_n)
self.iter_acc_early_release_in_epilogue = num_subtiles_in_overlap_region - 1
# Set prefetch distance for both initial and rolling prefetch (unified control)
# None = use num_ab_stage (default), 0 = disable prefetch, >0 = explicit distance
@@ -802,7 +809,7 @@ class Sm100BlockScaledPersistentDenseGemmKernel:
# Initialize acc_pipeline (barrier) and states
acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
num_acc_consumer_threads = len(self.epilog_warp_id) * (
num_acc_consumer_threads = self.threads_per_warp * len(self.epilog_warp_id) * (
2 if use_2cta_instrs else 1
)
acc_pipeline_consumer_group = pipeline.CooperativeGroup(
@@ -1458,7 +1465,7 @@ class Sm100BlockScaledPersistentDenseGemmKernel:
# Threads/warps participating in tma store pipeline
c_producer_group = pipeline.CooperativeGroup(
pipeline.Agent.Thread,
32 * len(self.epilog_warp_id),
self.threads_per_warp * len(self.epilog_warp_id),
)
c_pipeline = pipeline.PipelineTmaStore.create(
num_stages=self.num_c_stage,
@@ -1531,8 +1538,7 @@ class Sm100BlockScaledPersistentDenseGemmKernel:
if subtile_idx == self.iter_acc_early_release_in_epilogue:
# Fence for TMEM load
cute.arch.fence_view_async_tmem_load()
with cute.arch.elect_one():
acc_pipeline.consumer_release(acc_consumer_state)
acc_pipeline.consumer_release(acc_consumer_state)
acc_consumer_state.advance()
#
@@ -1545,7 +1551,7 @@ class Sm100BlockScaledPersistentDenseGemmKernel:
#
# Store C to shared memory
#
c_buffer = (num_prev_subtiles + real_subtile_idx) % self.num_c_stage
c_buffer = (num_prev_subtiles + subtile_idx) % self.num_c_stage
cute.copy(
tiled_copy_r2s,
tRS_rC,
@@ -1573,8 +1579,7 @@ class Sm100BlockScaledPersistentDenseGemmKernel:
# Async arrive accumulator buffer empty
#
if cutlass.const_expr(not self.overlapping_accum):
with cute.arch.elect_one():
acc_pipeline.consumer_release(acc_consumer_state)
acc_pipeline.consumer_release(acc_consumer_state)
acc_consumer_state.advance()
#
@@ -2340,9 +2345,6 @@ def run(
# Create scale factor tensor SFA/SFB
def create_scale_factor_tensor(l, mn, k, sf_vec_size, dtype):
def ceil_div(a, b):
return (a + b - 1) // b
sf_k = ceil_div(k, sf_vec_size)
ref_shape = (l, mn, sf_k)