Stream-K Reduction option as Runtime parameter and Compilation Error Fix (SK-Reduction) (#2145)

* Pass the Stream-K reduction strategy as a runtime parameter

* Apply clang-format

* Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp

Co-authored-by: John Afaganis <john.afaganis@amd.com>

* Update include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp


* remove comment

---------


[ROCm/composable_kernel commit: 6fad1c4874]
This commit is contained in:
Muhammed Emin Ozturk
2025-06-11 10:59:44 -07:00
committed by GitHub
parent 46624a1abd
commit 6111449cd6
7 changed files with 216 additions and 101 deletions

View File

@@ -1415,12 +1415,11 @@ template <uint32_t MPerBlock_,
index_t M01_ = 4>
struct BlockToCTileMap_GemmStreamK_v2
{
static constexpr uint32_t min_k_iters_per_sk_block = 2;
static constexpr uint32_t MPerBlock = MPerBlock_;
static constexpr uint32_t NPerBlock = NPerBlock_;
static constexpr uint32_t KPerBlock = KPerBlock_;
static constexpr StreamKReductionStrategy ReductionStrategy = ReductionStrategy_;
static constexpr uint32_t tile_swizzle_sub_m = TileSwizzleSubM_;
static constexpr uint32_t min_k_iters_per_sk_block = 2;
static constexpr uint32_t MPerBlock = MPerBlock_;
static constexpr uint32_t NPerBlock = NPerBlock_;
static constexpr uint32_t KPerBlock = KPerBlock_;
static constexpr uint32_t tile_swizzle_sub_m = TileSwizzleSubM_;
//--------------------------------------
// pass to device
@@ -1433,10 +1432,17 @@ struct BlockToCTileMap_GemmStreamK_v2
MDiv k_iters_per_tile;
MDiv equiv_tiles_big; // for reduction
MDiv equiv_tiles_little; // for reduction
StreamKReductionStrategy reduction_strategy;
// prefer construct on host
__host__ __device__ BlockToCTileMap_GemmStreamK_v2(
uint32_t m, uint32_t n, uint32_t k, uint32_t grid_size = 1, uint32_t streamk_sel = 1)
uint32_t m,
uint32_t n,
uint32_t k,
uint32_t grid_size = 1,
uint32_t streamk_sel = 1,
StreamKReductionStrategy reduction_strategy_ = StreamKReductionStrategy::Atomic)
: reduction_strategy(reduction_strategy_)
{
// total output tiles
@@ -1546,7 +1552,7 @@ struct BlockToCTileMap_GemmStreamK_v2
// Using multiple blocks for parallel reduction
reduction_start_block_idx = dp_start_block_idx + dp_num_blocks;
if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction)
if(reduction_strategy == ck::StreamKReductionStrategy::Reduction)
{
// Add additional safety checks
if(k_iters_per_big_block > 0 && k_iters_per_tile.get() > 0)
@@ -1589,7 +1595,7 @@ struct BlockToCTileMap_GemmStreamK_v2
__host__ __device__ index_t get_grid_dims() const
{
if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction)
if(reduction_strategy == StreamKReductionStrategy::Reduction)
{
// return dim3(reduction_start_block_idx + get_sk_tiles(), 1, 1);
return reduction_start_block_idx + get_sk_tiles();

View File

@@ -513,7 +513,8 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
index_t StrideB_,
index_t StrideC_,
index_t Streamk_sel_,
index_t Grid_size_)
index_t Grid_size_,
StreamKReductionStrategy reduction_strategy_)
: M{M_},
N{N_},
K{K_},
@@ -522,6 +523,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
StrideC{StrideC_},
Streamk_sel{Streamk_sel_},
Grid_size{Grid_size_},
reduction_strategy{reduction_strategy_}, // Initialize the member variable
MPadded{CalculateMPadded(M_)},
NPadded{CalculateNPadded(N_)},
KRead{CalculateKRead(K_, 1)},
@@ -550,8 +552,13 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
<< "AK0:" << AK0 << ", "
<< "BK0:" << BK0 << ", "
<< "MBlock: " << MBlock << ", "
<< "NBlock: " << NBlock << ", Stream-K Selection:" << Streamk_sel
<< ", Grid size:" << Grid_size << "}" << std::endl;
<< "NBlock: " << NBlock << ", "
<< "Stream-K Selection:" << Streamk_sel << ", "
<< "Grid size:" << Grid_size << ", "
<< "Reduction Strategy:"
<< (reduction_strategy == StreamKReductionStrategy::Atomic ? "Atomic"
: "Reduction")
<< "}" << std::endl;
}
index_t M;
@@ -562,6 +569,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
index_t StrideC;
index_t Streamk_sel;
mutable index_t Grid_size;
StreamKReductionStrategy reduction_strategy;
index_t MPadded;
index_t NPadded;
index_t KRead;
@@ -585,13 +593,26 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
index_t StrideB_,
index_t StrideC_,
index_t Streamk_sel_,
index_t Grid_size_)
: Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, Streamk_sel_, Grid_size_},
index_t Grid_size_,
StreamKReductionStrategy reduction_strategy_)
: Problem{M_,
N_,
K_,
StrideA_,
StrideB_,
StrideC_,
Streamk_sel_,
Grid_size_,
reduction_strategy_},
p_a_grid{p_a_grid_},
p_b_grid{p_b_grid_},
p_c_grid{p_c_grid_},
block_2_ctile_map_streamk(
M_, N_, AK0Number * CalculateKPadded(K_, 1), Grid_size_, Streamk_sel_)
block_2_ctile_map_streamk(M_,
N_,
AK0Number * CalculateKPadded(K_, 1),
Grid_size_,
Streamk_sel_,
reduction_strategy_)
{
}
@@ -1267,11 +1288,13 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
Block2CTileMap_streamk block_2_ctile_map_streamk(problem.M,
problem.N,
AK0Number * problem.KPadded,
problem.Grid_size,
problem.Streamk_sel);
problem.Streamk_sel,
problem.reduction_strategy);
uint32_t iter_start, iter_end;
bool is_sk_block, is_dp_block, is_reduction_block;
index_t num_k_block_main_loop;
@@ -1286,6 +1309,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
uint32_t* p_semaphore = reinterpret_cast<uint32_t*>(
reinterpret_cast<char*>(p_workspace) +
block_2_ctile_map_streamk.get_workspace_size_for_acc(sizeof(AccDataType)));
for(auto block_idx = get_block_1d_id();
block_idx < block_2_ctile_map_streamk.get_grid_dims();
block_idx += gridDim.x)
@@ -1301,8 +1325,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
block_2_ctile_map_streamk.get_block_itr(block_idx, iter_start, iter_end);
num_k_block_main_loop = iter_end - iter_start;
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
if(problem.reduction_strategy == StreamKReductionStrategy::Reduction)
{
is_reduction_block = static_cast<uint32_t>(block_idx) >=
block_2_ctile_map_streamk.reduction_start_block_idx;
@@ -1890,8 +1913,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
}
else if(is_sk_block)
{
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Atomic)
if(problem.reduction_strategy == StreamKReductionStrategy::Atomic)
{
// each block copy its data from LDS to global
c_shuffle_block_copy_lds_to_global
@@ -1903,8 +1925,8 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
}
else if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
else if(problem.reduction_strategy ==
StreamKReductionStrategy::Reduction)
{
// constexpr offset
c_block_copy_lds_to_partial_acc.SetSrcSliceOrigin(
@@ -1936,8 +1958,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
}
});
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
if(problem.reduction_strategy == StreamKReductionStrategy::Reduction)
{
if(is_sk_block)
{
@@ -1952,8 +1973,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
iter_end -= current_iter_length;
if(iter_end <= iter_start)
break;
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
if(problem.reduction_strategy == StreamKReductionStrategy::Reduction)
{
block_acc_offset -= MPerBlock * NPerBlock;
}
@@ -2008,7 +2028,8 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
problem.N,
AK0Number * problem.KPadded,
problem.Grid_size,
problem.Streamk_sel);
problem.Streamk_sel,
problem.reduction_strategy);
for(auto block_idx = get_block_1d_id();
block_idx < block_2_ctile_map_streamk.get_grid_dims();
block_idx += gridDim.x)
@@ -2027,8 +2048,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
reinterpret_cast<char*>(p_workspace) +
block_2_ctile_map_streamk.get_workspace_size_for_acc(sizeof(AccDataType)));
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
if(problem.reduction_strategy == StreamKReductionStrategy::Reduction)
{
is_reduction_block = static_cast<uint32_t>(block_idx) >=
block_2_ctile_map_streamk.reduction_start_block_idx;
@@ -2644,8 +2664,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
}
else if(is_sk_block)
{
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Atomic)
if(problem.reduction_strategy == StreamKReductionStrategy::Atomic)
{
// each block copy its data from LDS to global
c_shuffle_block_copy_lds_to_global
@@ -2657,8 +2676,8 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
}
else if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
else if(problem.reduction_strategy ==
StreamKReductionStrategy::Reduction)
{
// constexpr offset
c_block_copy_lds_to_partial_acc.SetSrcSliceOrigin(
@@ -2693,16 +2712,14 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
iter_end -= current_iter_length;
if(iter_end <= iter_start)
break;
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
if(problem.reduction_strategy == StreamKReductionStrategy::Reduction)
{
block_acc_offset -= MPerBlock * NPerBlock;
}
// make sure next loop LDS is ready for use
block_sync_lds();
}
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
if(problem.reduction_strategy == StreamKReductionStrategy::Reduction)
{
if(is_sk_block)
{