mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-21 05:19:20 +00:00
Stream-K Reduction option as a Runtime parameter and Compilation Error Fix (SK-Reduction) (#2145)
* The reduction strategy is now passed as a runtime parameter
* clang
* Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp
Co-authored-by: John Afaganis <john.afaganis@amd.com>
* Update include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
* remove comment
---------
[ROCm/composable_kernel commit: 6fad1c4874]
This commit is contained in:
committed by
GitHub
parent
46624a1abd
commit
6111449cd6
@@ -1415,12 +1415,11 @@ template <uint32_t MPerBlock_,
|
||||
index_t M01_ = 4>
|
||||
struct BlockToCTileMap_GemmStreamK_v2
|
||||
{
|
||||
static constexpr uint32_t min_k_iters_per_sk_block = 2;
|
||||
static constexpr uint32_t MPerBlock = MPerBlock_;
|
||||
static constexpr uint32_t NPerBlock = NPerBlock_;
|
||||
static constexpr uint32_t KPerBlock = KPerBlock_;
|
||||
static constexpr StreamKReductionStrategy ReductionStrategy = ReductionStrategy_;
|
||||
static constexpr uint32_t tile_swizzle_sub_m = TileSwizzleSubM_;
|
||||
static constexpr uint32_t min_k_iters_per_sk_block = 2;
|
||||
static constexpr uint32_t MPerBlock = MPerBlock_;
|
||||
static constexpr uint32_t NPerBlock = NPerBlock_;
|
||||
static constexpr uint32_t KPerBlock = KPerBlock_;
|
||||
static constexpr uint32_t tile_swizzle_sub_m = TileSwizzleSubM_;
|
||||
|
||||
//--------------------------------------
|
||||
// pass to device
|
||||
@@ -1433,10 +1432,17 @@ struct BlockToCTileMap_GemmStreamK_v2
|
||||
MDiv k_iters_per_tile;
|
||||
MDiv equiv_tiles_big; // for reduction
|
||||
MDiv equiv_tiles_little; // for reduction
|
||||
StreamKReductionStrategy reduction_strategy;
|
||||
|
||||
// prefer construct on host
|
||||
__host__ __device__ BlockToCTileMap_GemmStreamK_v2(
|
||||
uint32_t m, uint32_t n, uint32_t k, uint32_t grid_size = 1, uint32_t streamk_sel = 1)
|
||||
uint32_t m,
|
||||
uint32_t n,
|
||||
uint32_t k,
|
||||
uint32_t grid_size = 1,
|
||||
uint32_t streamk_sel = 1,
|
||||
StreamKReductionStrategy reduction_strategy_ = StreamKReductionStrategy::Atomic)
|
||||
: reduction_strategy(reduction_strategy_)
|
||||
{
|
||||
|
||||
// total output tiles
|
||||
@@ -1546,7 +1552,7 @@ struct BlockToCTileMap_GemmStreamK_v2
|
||||
// Using multiple blocks for parallel reduction
|
||||
reduction_start_block_idx = dp_start_block_idx + dp_num_blocks;
|
||||
|
||||
if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction)
|
||||
if(reduction_strategy == ck::StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
// Add additional safety checks
|
||||
if(k_iters_per_big_block > 0 && k_iters_per_tile.get() > 0)
|
||||
@@ -1589,7 +1595,7 @@ struct BlockToCTileMap_GemmStreamK_v2
|
||||
|
||||
__host__ __device__ index_t get_grid_dims() const
|
||||
{
|
||||
if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction)
|
||||
if(reduction_strategy == StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
// return dim3(reduction_start_block_idx + get_sk_tiles(), 1, 1);
|
||||
return reduction_start_block_idx + get_sk_tiles();
|
||||
|
||||
75
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp
Normal file → Executable file
75
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp
Normal file → Executable file
@@ -513,7 +513,8 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
index_t StrideB_,
|
||||
index_t StrideC_,
|
||||
index_t Streamk_sel_,
|
||||
index_t Grid_size_)
|
||||
index_t Grid_size_,
|
||||
StreamKReductionStrategy reduction_strategy_)
|
||||
: M{M_},
|
||||
N{N_},
|
||||
K{K_},
|
||||
@@ -522,6 +523,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
StrideC{StrideC_},
|
||||
Streamk_sel{Streamk_sel_},
|
||||
Grid_size{Grid_size_},
|
||||
reduction_strategy{reduction_strategy_}, // Initialize the member variable
|
||||
MPadded{CalculateMPadded(M_)},
|
||||
NPadded{CalculateNPadded(N_)},
|
||||
KRead{CalculateKRead(K_, 1)},
|
||||
@@ -550,8 +552,13 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
<< "AK0:" << AK0 << ", "
|
||||
<< "BK0:" << BK0 << ", "
|
||||
<< "MBlock: " << MBlock << ", "
|
||||
<< "NBlock: " << NBlock << ", Stream-K Selection:" << Streamk_sel
|
||||
<< ", Grid size:" << Grid_size << "}" << std::endl;
|
||||
<< "NBlock: " << NBlock << ", "
|
||||
<< "Stream-K Selection:" << Streamk_sel << ", "
|
||||
<< "Grid size:" << Grid_size << ", "
|
||||
<< "Reduction Strategy:"
|
||||
<< (reduction_strategy == StreamKReductionStrategy::Atomic ? "Atomic"
|
||||
: "Reduction")
|
||||
<< "}" << std::endl;
|
||||
}
|
||||
|
||||
index_t M;
|
||||
@@ -562,6 +569,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
index_t StrideC;
|
||||
index_t Streamk_sel;
|
||||
mutable index_t Grid_size;
|
||||
StreamKReductionStrategy reduction_strategy;
|
||||
index_t MPadded;
|
||||
index_t NPadded;
|
||||
index_t KRead;
|
||||
@@ -585,13 +593,26 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
index_t StrideB_,
|
||||
index_t StrideC_,
|
||||
index_t Streamk_sel_,
|
||||
index_t Grid_size_)
|
||||
: Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, Streamk_sel_, Grid_size_},
|
||||
index_t Grid_size_,
|
||||
StreamKReductionStrategy reduction_strategy_)
|
||||
: Problem{M_,
|
||||
N_,
|
||||
K_,
|
||||
StrideA_,
|
||||
StrideB_,
|
||||
StrideC_,
|
||||
Streamk_sel_,
|
||||
Grid_size_,
|
||||
reduction_strategy_},
|
||||
p_a_grid{p_a_grid_},
|
||||
p_b_grid{p_b_grid_},
|
||||
p_c_grid{p_c_grid_},
|
||||
block_2_ctile_map_streamk(
|
||||
M_, N_, AK0Number * CalculateKPadded(K_, 1), Grid_size_, Streamk_sel_)
|
||||
block_2_ctile_map_streamk(M_,
|
||||
N_,
|
||||
AK0Number * CalculateKPadded(K_, 1),
|
||||
Grid_size_,
|
||||
Streamk_sel_,
|
||||
reduction_strategy_)
|
||||
|
||||
{
|
||||
}
|
||||
@@ -1267,11 +1288,13 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
|
||||
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
|
||||
|
||||
Block2CTileMap_streamk block_2_ctile_map_streamk(problem.M,
|
||||
problem.N,
|
||||
AK0Number * problem.KPadded,
|
||||
problem.Grid_size,
|
||||
problem.Streamk_sel);
|
||||
problem.Streamk_sel,
|
||||
problem.reduction_strategy);
|
||||
uint32_t iter_start, iter_end;
|
||||
bool is_sk_block, is_dp_block, is_reduction_block;
|
||||
index_t num_k_block_main_loop;
|
||||
@@ -1286,6 +1309,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
uint32_t* p_semaphore = reinterpret_cast<uint32_t*>(
|
||||
reinterpret_cast<char*>(p_workspace) +
|
||||
block_2_ctile_map_streamk.get_workspace_size_for_acc(sizeof(AccDataType)));
|
||||
|
||||
for(auto block_idx = get_block_1d_id();
|
||||
block_idx < block_2_ctile_map_streamk.get_grid_dims();
|
||||
block_idx += gridDim.x)
|
||||
@@ -1301,8 +1325,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
block_2_ctile_map_streamk.get_block_itr(block_idx, iter_start, iter_end);
|
||||
num_k_block_main_loop = iter_end - iter_start;
|
||||
|
||||
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
|
||||
StreamKReductionStrategy::Reduction)
|
||||
if(problem.reduction_strategy == StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
is_reduction_block = static_cast<uint32_t>(block_idx) >=
|
||||
block_2_ctile_map_streamk.reduction_start_block_idx;
|
||||
@@ -1890,8 +1913,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
}
|
||||
else if(is_sk_block)
|
||||
{
|
||||
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
|
||||
StreamKReductionStrategy::Atomic)
|
||||
if(problem.reduction_strategy == StreamKReductionStrategy::Atomic)
|
||||
{
|
||||
// each block copy its data from LDS to global
|
||||
c_shuffle_block_copy_lds_to_global
|
||||
@@ -1903,8 +1925,8 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
c_grid_buf);
|
||||
}
|
||||
else if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
|
||||
StreamKReductionStrategy::Reduction)
|
||||
else if(problem.reduction_strategy ==
|
||||
StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
// constexpr offset
|
||||
c_block_copy_lds_to_partial_acc.SetSrcSliceOrigin(
|
||||
@@ -1936,8 +1958,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
}
|
||||
});
|
||||
|
||||
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
|
||||
StreamKReductionStrategy::Reduction)
|
||||
if(problem.reduction_strategy == StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
if(is_sk_block)
|
||||
{
|
||||
@@ -1952,8 +1973,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
iter_end -= current_iter_length;
|
||||
if(iter_end <= iter_start)
|
||||
break;
|
||||
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
|
||||
StreamKReductionStrategy::Reduction)
|
||||
if(problem.reduction_strategy == StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
block_acc_offset -= MPerBlock * NPerBlock;
|
||||
}
|
||||
@@ -2008,7 +2028,8 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
problem.N,
|
||||
AK0Number * problem.KPadded,
|
||||
problem.Grid_size,
|
||||
problem.Streamk_sel);
|
||||
problem.Streamk_sel,
|
||||
problem.reduction_strategy);
|
||||
for(auto block_idx = get_block_1d_id();
|
||||
block_idx < block_2_ctile_map_streamk.get_grid_dims();
|
||||
block_idx += gridDim.x)
|
||||
@@ -2027,8 +2048,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
reinterpret_cast<char*>(p_workspace) +
|
||||
block_2_ctile_map_streamk.get_workspace_size_for_acc(sizeof(AccDataType)));
|
||||
|
||||
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
|
||||
StreamKReductionStrategy::Reduction)
|
||||
if(problem.reduction_strategy == StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
is_reduction_block = static_cast<uint32_t>(block_idx) >=
|
||||
block_2_ctile_map_streamk.reduction_start_block_idx;
|
||||
@@ -2644,8 +2664,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
}
|
||||
else if(is_sk_block)
|
||||
{
|
||||
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
|
||||
StreamKReductionStrategy::Atomic)
|
||||
if(problem.reduction_strategy == StreamKReductionStrategy::Atomic)
|
||||
{
|
||||
// each block copy its data from LDS to global
|
||||
c_shuffle_block_copy_lds_to_global
|
||||
@@ -2657,8 +2676,8 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
c_grid_buf);
|
||||
}
|
||||
else if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
|
||||
StreamKReductionStrategy::Reduction)
|
||||
else if(problem.reduction_strategy ==
|
||||
StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
// constexpr offset
|
||||
c_block_copy_lds_to_partial_acc.SetSrcSliceOrigin(
|
||||
@@ -2693,16 +2712,14 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
iter_end -= current_iter_length;
|
||||
if(iter_end <= iter_start)
|
||||
break;
|
||||
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
|
||||
StreamKReductionStrategy::Reduction)
|
||||
if(problem.reduction_strategy == StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
block_acc_offset -= MPerBlock * NPerBlock;
|
||||
}
|
||||
// make sure next loop LDS is ready for use
|
||||
block_sync_lds();
|
||||
}
|
||||
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
|
||||
StreamKReductionStrategy::Reduction)
|
||||
if(problem.reduction_strategy == StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
if(is_sk_block)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user