Update blackwell tutorial to be compatible with 4.5-dev version (#3130)

* Update blackwell tutorial to be compatible with 4.5-dev version

* update example for reverted changes

* add more example fixes
This commit is contained in:
Longsheng Du
2026-04-09 14:40:33 +08:00
committed by GitHub
parent bd01dd3651
commit 08185b9c3e
12 changed files with 29 additions and 29 deletions

View File

@@ -647,7 +647,7 @@ class Sm100BlockScaledPersistentDenseGemmKernel:
ab_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage]
acc_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage]
acc_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage]
tmem_dealloc_mbar_ptr: cutlass.Int64
tmem_dealloc_mbar: cutlass.Int64
tmem_holding_buf: cutlass.Int32
# (EPI_TILE_M, EPI_TILE_N, STAGE)
sC: cute.struct.Align[
@@ -826,11 +826,11 @@ class Sm100BlockScaledPersistentDenseGemmKernel:
# Tensor memory dealloc barrier init
tmem = utils.TmemAllocator(
storage.tmem_holding_buf,
storage.tmem_holding_buf.ptr,
barrier_for_retrieve=self.tmem_alloc_barrier,
allocator_warp_id=self.epilog_warp_id[0],
is_two_cta=use_2cta_instrs,
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr,
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar.ptr,
)
# Cluster arrive after barrier init

View File

@@ -648,7 +648,7 @@ class PersistentDenseGemmKernel:
acc_full_mbar_ptr: cute.struct.MemRange[
cutlass.Int64, self.num_acc_stage * 2
]
tmem_dealloc_mbar_ptr: cutlass.Int64
tmem_dealloc_mbar: cutlass.Int64
tmem_holding_buf: cutlass.Int32
smem = utils.SmemAllocator()
@@ -699,11 +699,11 @@ class PersistentDenseGemmKernel:
)
# Tensor memory dealloc barrier init
tmem = utils.TmemAllocator(
storage.tmem_holding_buf,
storage.tmem_holding_buf.ptr,
barrier_for_retrieve=tmem_alloc_barrier,
allocator_warp_id=self.epilog_warp_id[0],
is_two_cta=use_2cta_instrs,
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr,
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar.ptr,
)
# Cluster arrive after barrier init

View File

@@ -219,11 +219,11 @@ def kernel(
* len((mma_warp_id, *epilogue_warp_ids)), # 5 warps = 160 threads
)
tmem = utils.TmemAllocator(
storage.tmem_holding_buffer,
storage.tmem_holding_buffer.ptr,
barrier_for_retrieve=tmem_alloc_barrier,
allocator_warp_id=epilogue_warp_ids[0],
is_two_cta=True if use_2cta_instrs else False,
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar,
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar.ptr,
)
# Partition tensors for TMA; This requires the tensors partitioned for MMA

View File

@@ -152,11 +152,11 @@ def kernel(
* len((mma_warp_id, *epilogue_warp_ids)), # 5 warps = 160 threads
)
tmem = utils.TmemAllocator(
storage.tmem_holding_buffer,
storage.tmem_holding_buffer.ptr,
barrier_for_retrieve=tmem_alloc_barrier,
allocator_warp_id=epilogue_warp_ids[0],
is_two_cta=True,
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar,
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar.ptr,
)
num_tma_copy_bytes = (

View File

@@ -159,11 +159,11 @@ def kernel(
* len((mma_warp_id, *epilogue_warp_ids)), # 5 warps = 160 threads
)
tmem = utils.TmemAllocator(
storage.tmem_holding_buffer,
storage.tmem_holding_buffer.ptr,
barrier_for_retrieve=tmem_alloc_barrier,
allocator_warp_id=epilogue_warp_ids[0],
is_two_cta=True,
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar,
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar.ptr,
)
num_tma_copy_bytes = (

View File

@@ -184,11 +184,11 @@ def cluster_specific_kernel(
* len((mma_warp_id, *epilogue_warp_ids)), # 5 warps = 160 threads
)
tmem = utils.TmemAllocator(
storage.tmem_holding_buffer,
storage.tmem_holding_buffer.ptr,
barrier_for_retrieve=tmem_alloc_barrier,
allocator_warp_id=epilogue_warp_ids[0],
is_two_cta=True,
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar,
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar.ptr,
)
num_tma_copy_bytes = (

View File

@@ -171,11 +171,11 @@ def kernel(
* len((mma_warp_id, *epilogue_warp_ids)), # 5 warps = 160 threads
)
tmem = utils.TmemAllocator(
storage.tmem_holding_buffer,
storage.tmem_holding_buffer.ptr,
barrier_for_retrieve=tmem_alloc_barrier,
allocator_warp_id=epilogue_warp_ids[0],
is_two_cta=True,
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar,
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar.ptr,
)
num_tma_copy_bytes = (

View File

@@ -214,11 +214,11 @@ def gemm(
* len((mma_warp_id, *epilogue_warp_ids)), # 5 warps = 160 threads
)
tmem = utils.TmemAllocator(
storage.tmem_holding_buffer,
storage.tmem_holding_buffer.ptr,
barrier_for_retrieve=tmem_alloc_barrier,
allocator_warp_id=epilogue_warp_ids[0],
is_two_cta=True,
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar,
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar.ptr,
)
num_tma_copy_bytes = (

View File

@@ -756,7 +756,7 @@ class PersistentDenseGemmKernel:
acc_full_mbar_ptr: cute.struct.MemRange[
cutlass.Int64, self.num_acc_stage * 2
]
tmem_dealloc_mbar_ptr: cutlass.Int64
tmem_dealloc_mbar: cutlass.Int64
tmem_holding_buf: cutlass.Int32
smem = utils.SmemAllocator()
@@ -806,11 +806,11 @@ class PersistentDenseGemmKernel:
)
# Tensor memory dealloc barrier init
tmem = utils.TmemAllocator(
storage.tmem_holding_buf,
storage.tmem_holding_buf.ptr,
barrier_for_retrieve=tmem_alloc_barrier,
allocator_warp_id=self.epilog_warp_id[0],
is_two_cta=use_2cta_instrs,
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr,
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar.ptr,
)
# Cluster arrive after barrier init

View File

@@ -672,7 +672,7 @@ class PersistentDenseGemmKernel:
acc_full_mbar_ptr: cute.struct.MemRange[
cutlass.Int64, self.num_acc_stage * 2
]
tmem_dealloc_mbar_ptr: cutlass.Int64
tmem_dealloc_mbar: cutlass.Int64
tmem_holding_buf: cutlass.Int32
smem = utils.SmemAllocator()
@@ -723,11 +723,11 @@ class PersistentDenseGemmKernel:
)
# Tensor memory dealloc barrier init
tmem = utils.TmemAllocator(
storage.tmem_holding_buf,
storage.tmem_holding_buf.ptr,
barrier_for_retrieve=tmem_alloc_barrier,
allocator_warp_id=self.epilogue_warp_id[0],
is_two_cta=use_2cta_instrs,
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr,
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar.ptr,
)
# Cluster arrive after barrier init

View File

@@ -541,7 +541,7 @@ class PersistentDenseGemmKernel:
ab_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage]
acc_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage]
acc_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage]
tmem_dealloc_mbar_ptr: cutlass.Int64
tmem_dealloc_mbar: cutlass.Int64
tmem_holding_buf: cutlass.Int32
# (EPI_TILE_M, EPI_TILE_N, STAGE)
sC: cute.struct.Align[
@@ -660,8 +660,8 @@ class PersistentDenseGemmKernel:
smem = utils.SmemAllocator()
storage = smem.allocate(self.shared_storage)
tmem_dealloc_mbar_ptr = storage.tmem_dealloc_mbar_ptr
tmem_holding_buf = storage.tmem_holding_buf
tmem_dealloc_mbar_ptr = storage.tmem_dealloc_mbar.ptr
tmem_holding_buf = storage.tmem_holding_buf.ptr
# Initialize mainloop ab_pipeline (barrier) and states
ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)

View File

@@ -369,7 +369,7 @@
" num_threads=threads_per_cta,\n",
" )\n",
" tmem = utils.TmemAllocator(\n",
" storage.tmem_holding_buf,\n",
" storage.tmem_holding_buf.ptr,\n",
" barrier_for_retrieve=tmem_alloc_barrier,\n",
" )\n",
" num_tmem_cols = 512\n",
@@ -742,7 +742,7 @@
" num_threads=threads_per_cta,\n",
" )\n",
" tmem = utils.TmemAllocator(\n",
" storage.tmem_holding_buf,\n",
" storage.tmem_holding_buf.ptr,\n",
" barrier_for_retrieve=tmem_alloc_barrier,\n",
" )\n",
" num_tmem_cols = 512\n",