mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-05-11 17:00:05 +00:00
Update blackwell tutorial to be compatible with 4.5-dev version (#3130)
* Update blackwell tutorial to be compatible with 4.5-dev version * update example for reverted changes * add more example fix
This commit is contained in:
@@ -647,7 +647,7 @@ class Sm100BlockScaledPersistentDenseGemmKernel:
|
||||
ab_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage]
|
||||
acc_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage]
|
||||
acc_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage]
|
||||
tmem_dealloc_mbar_ptr: cutlass.Int64
|
||||
tmem_dealloc_mbar: cutlass.Int64
|
||||
tmem_holding_buf: cutlass.Int32
|
||||
# (EPI_TILE_M, EPI_TILE_N, STAGE)
|
||||
sC: cute.struct.Align[
|
||||
@@ -826,11 +826,11 @@ class Sm100BlockScaledPersistentDenseGemmKernel:
|
||||
|
||||
# Tensor memory dealloc barrier init
|
||||
tmem = utils.TmemAllocator(
|
||||
storage.tmem_holding_buf,
|
||||
storage.tmem_holding_buf.ptr,
|
||||
barrier_for_retrieve=self.tmem_alloc_barrier,
|
||||
allocator_warp_id=self.epilog_warp_id[0],
|
||||
is_two_cta=use_2cta_instrs,
|
||||
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr,
|
||||
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar.ptr,
|
||||
)
|
||||
|
||||
# Cluster arrive after barrier init
|
||||
|
||||
@@ -648,7 +648,7 @@ class PersistentDenseGemmKernel:
|
||||
acc_full_mbar_ptr: cute.struct.MemRange[
|
||||
cutlass.Int64, self.num_acc_stage * 2
|
||||
]
|
||||
tmem_dealloc_mbar_ptr: cutlass.Int64
|
||||
tmem_dealloc_mbar: cutlass.Int64
|
||||
tmem_holding_buf: cutlass.Int32
|
||||
|
||||
smem = utils.SmemAllocator()
|
||||
@@ -699,11 +699,11 @@ class PersistentDenseGemmKernel:
|
||||
)
|
||||
# Tensor memory dealloc barrier init
|
||||
tmem = utils.TmemAllocator(
|
||||
storage.tmem_holding_buf,
|
||||
storage.tmem_holding_buf.ptr,
|
||||
barrier_for_retrieve=tmem_alloc_barrier,
|
||||
allocator_warp_id=self.epilog_warp_id[0],
|
||||
is_two_cta=use_2cta_instrs,
|
||||
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr,
|
||||
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar.ptr,
|
||||
)
|
||||
|
||||
# Cluster arrive after barrier init
|
||||
|
||||
@@ -219,11 +219,11 @@ def kernel(
|
||||
* len((mma_warp_id, *epilogue_warp_ids)), # 5 warps = 160 threads
|
||||
)
|
||||
tmem = utils.TmemAllocator(
|
||||
storage.tmem_holding_buffer,
|
||||
storage.tmem_holding_buffer.ptr,
|
||||
barrier_for_retrieve=tmem_alloc_barrier,
|
||||
allocator_warp_id=epilogue_warp_ids[0],
|
||||
is_two_cta=True if use_2cta_instrs else False,
|
||||
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar,
|
||||
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar.ptr,
|
||||
)
|
||||
|
||||
# Partition tensors for TMA; This requires the tensors partitioned for MMA
|
||||
|
||||
@@ -152,11 +152,11 @@ def kernel(
|
||||
* len((mma_warp_id, *epilogue_warp_ids)), # 5 warps = 160 threads
|
||||
)
|
||||
tmem = utils.TmemAllocator(
|
||||
storage.tmem_holding_buffer,
|
||||
storage.tmem_holding_buffer.ptr,
|
||||
barrier_for_retrieve=tmem_alloc_barrier,
|
||||
allocator_warp_id=epilogue_warp_ids[0],
|
||||
is_two_cta=True,
|
||||
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar,
|
||||
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar.ptr,
|
||||
)
|
||||
|
||||
num_tma_copy_bytes = (
|
||||
|
||||
@@ -159,11 +159,11 @@ def kernel(
|
||||
* len((mma_warp_id, *epilogue_warp_ids)), # 5 warps = 160 threads
|
||||
)
|
||||
tmem = utils.TmemAllocator(
|
||||
storage.tmem_holding_buffer,
|
||||
storage.tmem_holding_buffer.ptr,
|
||||
barrier_for_retrieve=tmem_alloc_barrier,
|
||||
allocator_warp_id=epilogue_warp_ids[0],
|
||||
is_two_cta=True,
|
||||
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar,
|
||||
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar.ptr,
|
||||
)
|
||||
|
||||
num_tma_copy_bytes = (
|
||||
|
||||
@@ -184,11 +184,11 @@ def cluster_specific_kernel(
|
||||
* len((mma_warp_id, *epilogue_warp_ids)), # 5 warps = 160 threads
|
||||
)
|
||||
tmem = utils.TmemAllocator(
|
||||
storage.tmem_holding_buffer,
|
||||
storage.tmem_holding_buffer.ptr,
|
||||
barrier_for_retrieve=tmem_alloc_barrier,
|
||||
allocator_warp_id=epilogue_warp_ids[0],
|
||||
is_two_cta=True,
|
||||
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar,
|
||||
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar.ptr,
|
||||
)
|
||||
|
||||
num_tma_copy_bytes = (
|
||||
|
||||
@@ -171,11 +171,11 @@ def kernel(
|
||||
* len((mma_warp_id, *epilogue_warp_ids)), # 5 warps = 160 threads
|
||||
)
|
||||
tmem = utils.TmemAllocator(
|
||||
storage.tmem_holding_buffer,
|
||||
storage.tmem_holding_buffer.ptr,
|
||||
barrier_for_retrieve=tmem_alloc_barrier,
|
||||
allocator_warp_id=epilogue_warp_ids[0],
|
||||
is_two_cta=True,
|
||||
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar,
|
||||
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar.ptr,
|
||||
)
|
||||
|
||||
num_tma_copy_bytes = (
|
||||
|
||||
@@ -214,11 +214,11 @@ def gemm(
|
||||
* len((mma_warp_id, *epilogue_warp_ids)), # 5 warps = 160 threads
|
||||
)
|
||||
tmem = utils.TmemAllocator(
|
||||
storage.tmem_holding_buffer,
|
||||
storage.tmem_holding_buffer.ptr,
|
||||
barrier_for_retrieve=tmem_alloc_barrier,
|
||||
allocator_warp_id=epilogue_warp_ids[0],
|
||||
is_two_cta=True,
|
||||
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar,
|
||||
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar.ptr,
|
||||
)
|
||||
|
||||
num_tma_copy_bytes = (
|
||||
|
||||
@@ -756,7 +756,7 @@ class PersistentDenseGemmKernel:
|
||||
acc_full_mbar_ptr: cute.struct.MemRange[
|
||||
cutlass.Int64, self.num_acc_stage * 2
|
||||
]
|
||||
tmem_dealloc_mbar_ptr: cutlass.Int64
|
||||
tmem_dealloc_mbar: cutlass.Int64
|
||||
tmem_holding_buf: cutlass.Int32
|
||||
|
||||
smem = utils.SmemAllocator()
|
||||
@@ -806,11 +806,11 @@ class PersistentDenseGemmKernel:
|
||||
)
|
||||
# Tensor memory dealloc barrier init
|
||||
tmem = utils.TmemAllocator(
|
||||
storage.tmem_holding_buf,
|
||||
storage.tmem_holding_buf.ptr,
|
||||
barrier_for_retrieve=tmem_alloc_barrier,
|
||||
allocator_warp_id=self.epilog_warp_id[0],
|
||||
is_two_cta=use_2cta_instrs,
|
||||
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr,
|
||||
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar.ptr,
|
||||
)
|
||||
|
||||
# Cluster arrive after barrier init
|
||||
|
||||
@@ -672,7 +672,7 @@ class PersistentDenseGemmKernel:
|
||||
acc_full_mbar_ptr: cute.struct.MemRange[
|
||||
cutlass.Int64, self.num_acc_stage * 2
|
||||
]
|
||||
tmem_dealloc_mbar_ptr: cutlass.Int64
|
||||
tmem_dealloc_mbar: cutlass.Int64
|
||||
tmem_holding_buf: cutlass.Int32
|
||||
|
||||
smem = utils.SmemAllocator()
|
||||
@@ -723,11 +723,11 @@ class PersistentDenseGemmKernel:
|
||||
)
|
||||
# Tensor memory dealloc barrier init
|
||||
tmem = utils.TmemAllocator(
|
||||
storage.tmem_holding_buf,
|
||||
storage.tmem_holding_buf.ptr,
|
||||
barrier_for_retrieve=tmem_alloc_barrier,
|
||||
allocator_warp_id=self.epilogue_warp_id[0],
|
||||
is_two_cta=use_2cta_instrs,
|
||||
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr,
|
||||
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar.ptr,
|
||||
)
|
||||
|
||||
# Cluster arrive after barrier init
|
||||
|
||||
@@ -541,7 +541,7 @@ class PersistentDenseGemmKernel:
|
||||
ab_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage]
|
||||
acc_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage]
|
||||
acc_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage]
|
||||
tmem_dealloc_mbar_ptr: cutlass.Int64
|
||||
tmem_dealloc_mbar: cutlass.Int64
|
||||
tmem_holding_buf: cutlass.Int32
|
||||
# (EPI_TILE_M, EPI_TILE_N, STAGE)
|
||||
sC: cute.struct.Align[
|
||||
@@ -660,8 +660,8 @@ class PersistentDenseGemmKernel:
|
||||
smem = utils.SmemAllocator()
|
||||
storage = smem.allocate(self.shared_storage)
|
||||
|
||||
tmem_dealloc_mbar_ptr = storage.tmem_dealloc_mbar_ptr
|
||||
tmem_holding_buf = storage.tmem_holding_buf
|
||||
tmem_dealloc_mbar_ptr = storage.tmem_dealloc_mbar.ptr
|
||||
tmem_holding_buf = storage.tmem_holding_buf.ptr
|
||||
|
||||
# Initialize mainloop ab_pipeline (barrier) and states
|
||||
ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
|
||||
|
||||
@@ -369,7 +369,7 @@
|
||||
" num_threads=threads_per_cta,\n",
|
||||
" )\n",
|
||||
" tmem = utils.TmemAllocator(\n",
|
||||
" storage.tmem_holding_buf,\n",
|
||||
" storage.tmem_holding_buf.ptr,\n",
|
||||
" barrier_for_retrieve=tmem_alloc_barrier,\n",
|
||||
" )\n",
|
||||
" num_tmem_cols = 512\n",
|
||||
@@ -742,7 +742,7 @@
|
||||
" num_threads=threads_per_cta,\n",
|
||||
" )\n",
|
||||
" tmem = utils.TmemAllocator(\n",
|
||||
" storage.tmem_holding_buf,\n",
|
||||
" storage.tmem_holding_buf.ptr,\n",
|
||||
" barrier_for_retrieve=tmem_alloc_barrier,\n",
|
||||
" )\n",
|
||||
" num_tmem_cols = 512\n",
|
||||
|
||||
Reference in New Issue
Block a user